Train Decoder

Training script for the ViT decoder with reconstruction losses

PreEncodedDataset


def PreEncodedDataset(
    encoded_dir
):

Load pre-encoded embeddings and images from .pt files.


setup_dataloaders


def setup_dataloaders(
    cfg, preencoded:bool=False
):

setup_models


def setup_models(
    cfg, device, preencoded, verbose:bool=True
):

setup_tstate


def setup_tstate(
    cfg, device, decoder, encoder:NoneType=None
):

Training state: losses, optimizers, schedulers, and AMP scalers.


get_embeddings_batch


def get_embeddings_batch(
    batch, encoder:NoneType=None, preencoded:bool=False, device:str='cuda', allow_grad:bool=False
):

train_step


def train_step(
    epoch, enc_out, img_real, decoder, tstate, # named tuple containing optimizers, loss fns, scalers
    cfg, # config
    note_weights:NoneType=None
):

Training step for the decoder.


train


def train(
    cfg:DictConfig
):