train_enc
Encoder training script for midi_rae
Curriculum Learning
CurriculumSchedule
def CurriculumSchedule(
cfg, intervals:NoneType=None, verbose:bool=True
):
Manage curriculum learning for shifts in pitch and time.
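As a rough illustration of the idea, a curriculum for pitch and time shifts can widen the allowed augmentation ranges as training progresses. The sketch below is hypothetical: `StepCurriculum` and its `milestones` format, a sorted list of `(start_step, max_pitch_shift, max_time_shift)` tuples, are assumptions made for this example and are not the actual `CurriculumSchedule` interface or its `intervals` argument.

```python
from bisect import bisect_right


class StepCurriculum:
    """Hypothetical sketch: piecewise-constant shift ranges that widen at step milestones.

    `milestones` is an ASSUMED format: (start_step, max_pitch_shift, max_time_shift)
    tuples. This is not the midi_rae CurriculumSchedule API.
    """

    def __init__(self, milestones):
        self.milestones = sorted(milestones)
        self.starts = [m[0] for m in self.milestones]

    def __call__(self, step):
        # Find the last milestone whose start_step <= step.
        i = bisect_right(self.starts, step) - 1
        if i < 0:
            return 0, 0  # before the first milestone: no shifting
        _, max_pitch, max_time = self.milestones[i]
        return max_pitch, max_time


sched = StepCurriculum([(0, 0, 0), (1000, 2, 4), (5000, 6, 12)])
print(sched(1500))  # (2, 4)
```

A step-wise schedule like this keeps early training on unshifted data and only introduces larger transpositions and time offsets once the encoder has had time to fit the easier cases; a linear ramp between milestones would be an equally plausible design.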
Compute Loss On Batch
compute_batch_loss
def compute_batch_loss(
batch, encoder, cfg, global_step, mae_decoder:NoneType=None, ema_encoder:NoneType=None, mep_model:NoneType=None,
debug:bool=False
):
Compute the loss on a batch and return it along with other auxiliary variables (used for both training and validation).
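Because `mae_decoder`, `ema_encoder`, and `mep_model` all default to `None`, the loss presumably includes a term only for the components that are actually passed in. The minimal sketch below shows that optional-component pattern with plain callables. `combine_losses`, the `weights` dict, and the way each head is called are hypothetical illustrations, not the midi_rae implementation, and the EMA encoder is omitted for brevity.

```python
from typing import Callable, Optional


def combine_losses(batch, encoder: Callable, weights: dict,
                   mae_decoder: Optional[Callable] = None,
                   mep_model: Optional[Callable] = None):
    """Hypothetical sketch: sum weighted loss terms from whichever optional heads are given.

    Returns (total_loss, aux) where aux maps each active term's name to its unweighted value.
    """
    z = encoder(batch)  # shared encoder output reused by every head
    aux = {}
    total = 0.0
    if mae_decoder is not None:
        aux["mae"] = mae_decoder(z, batch)  # assumed reconstruction-style loss
        total += weights.get("mae", 1.0) * aux["mae"]
    if mep_model is not None:
        aux["mep"] = mep_model(z, batch)  # assumed auxiliary prediction loss
        total += weights.get("mep", 1.0) * aux["mep"]
    return total, aux


total, aux = combine_losses([1, 2, 3], encoder=sum, weights={"mae": 2.0},
                            mae_decoder=lambda z, b: z * 0.5)
print(total, aux)  # 6.0 {'mae': 3.0}
```

Returning the unweighted terms alongside the total makes it easy to log each component separately during training and validation.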
Main Training Loop
train
def train(
cfg:DictConfig
):