utils

utilities

to_scalar


def to_scalar(
    x
):

Converts `x` to a scalar, for when you're sick of guessing what types WandB will accept for logging.


binarize


def binarize(
    img, thresh:float=0.5
):

Used to turn images with values in [0, 1] into binary floats. Leave `thresh` the same in ALL parts of the code, for standardization.

a = torch.randn(10,)
print(a)
print(binarize(a))
print(binarize(a,0.1))
tensor([-1.1345,  0.1578,  0.3754,  0.3470, -0.3611,  1.7399, -0.6777, -0.4291,
         0.1035,  0.3489])
tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.])
tensor([0., 1., 1., 1., 0., 1., 0., 0., 1., 1.])

freemem


def freemem(
    
):

Frees up unused memory.


crevert


def crevert(
    
):

Reverts the color to the default.


cchange


def cchange(
    color
):

Just changes the color for whatever is printed next.


rprint


def rprint(
    text, color:NoneType=None, width:NoneType=None
):

Prints right-justified text. Kept for backwards compatibility.


cjprint


def cjprint(
    args:VAR_POSITIONAL, color:NoneType=None, justify:str='left', width:NoneType=None, bright:bool=True,
    revert:bool=True, end:NoneType=None
):

Prints optionally colored and/or justified text.

from midi_rae.core import * 
cjprint(logo, color="red") 
width = max(len(line) for line in logo.splitlines())
rprint("I just met you","yellow", width=width) 
cjprint("Hello here's some text, revert=False",color="magenta", revert=False) 
print("And here's text in the same color") 
cchange('blue') 
print("New text now in blue") 
cjprint("Sending visualization data",color="green",end="")
rprint("But here's some numbers","yellow",width=120)
crevert()
print("Back to default")

          ▬▬    ▬▬▬    ▬▬                                       ▬▬                                  "Scrawl Me Maybe"
                 ▬▬                                                                                                  
 ▬▬  ▬▬  ▬▬▬     ▬▬   ▬▬▬        ▬▬▬▬▬▬ ▬▬▬▬   ▬▬▬▬            ▬▬▬   ▬▬▬▬ ▬▬▬▬▬▬  ▬▬▬▬           ▬▬▬▬▬   ▬▬▬▬  ▬▬▬▬▬ 
 ▬▬▬▬▬▬▬  ▬▬  ▬▬▬▬▬    ▬▬   ▬▬▬▬  ▬▬  ▬▬   ▬▬ ▬▬  ▬▬   ▬▬▬▬     ▬▬  ▬▬  ▬▬ ▬▬  ▬▬    ▬▬   ▬▬▬▬  ▬▬      ▬▬  ▬▬ ▬▬  ▬▬
 ▬▬ ▬ ▬▬  ▬▬ ▬▬  ▬▬    ▬▬         ▬▬    ▬▬▬▬▬ ▬▬▬▬▬▬            ▬▬  ▬▬▬▬▬▬ ▬▬  ▬▬ ▬▬▬▬▬          ▬▬▬▬▬  ▬▬  ▬▬ ▬▬  ▬▬
 ▬▬   ▬▬  ▬▬ ▬▬  ▬▬    ▬▬         ▬▬   ▬▬  ▬▬ ▬▬            ▬▬  ▬▬  ▬▬     ▬▬▬▬▬ ▬▬  ▬▬              ▬▬ ▬▬  ▬▬ ▬▬  ▬▬
 ▬▬   ▬▬ ▬▬▬▬ ▬▬▬▬▬▬  ▬▬▬▬       ▬▬▬▬   ▬▬▬▬▬  ▬▬▬▬▬        ▬▬  ▬▬   ▬▬▬▬▬ ▬▬     ▬▬▬▬▬         ▬▬▬▬▬▬   ▬▬▬▬  ▬▬  ▬▬
                                                             ▬▬▬▬         ▬▬▬▬      
                                                                                                       I just met you
Hello here's some text, revert=False
And here's text in the same color
New text now in blue
Sending visualization data                                                                                                 But here's some numbers
Back to default

set_seed


def set_seed(
    seed:int=42, deterministic:bool=False
):

Set all random seeds for reproducibility


param_count


def param_count(
    model
):

Returns the total and trainable parameter counts of a model.


save_checkpoint


def save_checkpoint(
    model, epoch, val_loss, cfg, optimizer:NoneType=None, save_every:int=25, n_keep:int=5, verbose:bool=True,
    tag:str=''
):

Saves a new checkpoint, keeping the best one and the most recent `n_keep`. Can loop over multiple models, saving a separate file for each model.


load_checkpoint


def load_checkpoint(
    model, ckpt_path:str, return_all:bool=False, weights_only:bool=False, strict:bool=False
):

Loads a model (and optionally other saved items) from a checkpoint file.


load_optimizer_state_partial


def load_optimizer_state_partial(
    optimizer, ckpt_opt, device
):

Loads optimizer state for as many params as the checkpoint has, ignoring any extras. Useful because sometimes things break.


EMAModel


def EMAModel(
    model, eta:float=0.99, update_every:int=1, dtype:dtype=torch.bfloat16
):

Exponential moving average wrapper for stable teacher-student training.