Train Decoder
Training script for the ViT decoder with reconstruction losses
PreEncodedDataset
def PreEncodedDataset(
encoded_dir
):
Load pre-encoded embeddings and images from .pt files.
setup_dataloaders
def setup_dataloaders(
cfg, preencoded:bool=False
):
setup_models
def setup_models(
cfg, device, preencoded, verbose:bool=True
):
setup_tstate
def setup_tstate(
cfg, device, decoder, encoder:NoneType=None
):
Set up the training state: loss functions, optimizers, schedulers, and AMP gradient scalers.
get_embeddings_batch
def get_embeddings_batch(
batch, encoder:NoneType=None, preencoded:bool=False, device:str='cuda', allow_grad:bool=False
):
train_step
def train_step(
epoch, enc_out, img_real, decoder, tstate, # named tuple containing optimizers, loss functions, and scalers
cfg, # config
note_weights:NoneType=None
):
Training step for the decoder.
train
def train(
cfg:DictConfig
):