train_enc

Encoder training script for midi_rae

Curriculum Learning


CurriculumSchedule


def CurriculumSchedule(
    cfg, intervals:NoneType=None, verbose:bool=True
):

Manage curriculum learning for shifts in pitch and time.
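The docstring does not specify the schedule format. The sketch below illustrates one possible form. It assumes curriculum stages are keyed by training step and that each stage widens the allowed pitch and time shifts as training progresses. `Stage`, `ToyCurriculum`, and the units (semitones, time steps) are hypothetical. They are not the actual `CurriculumSchedule` API or its `intervals` format.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class Stage:
    # Hypothetical stage: from `start_step` on, allow random shifts of up to
    # +/- max_pitch_shift semitones and +/- max_time_shift time steps.
    start_step: int
    max_pitch_shift: int
    max_time_shift: int


class ToyCurriculum:
    """Hypothetical piecewise-constant schedule for pitch/time shift ranges."""

    def __init__(self, stages: Sequence[Stage], verbose: bool = True):
        if not stages:
            raise ValueError("need at least one stage")
        # Sort so stages may be given in any order.
        self.stages = sorted(stages, key=lambda s: s.start_step)
        self.verbose = verbose
        self._active: Optional[Stage] = None

    def __call__(self, step: int) -> Tuple[int, int]:
        # Use the last stage whose start_step has been reached; steps before
        # the first stage fall back to the first stage.
        active = self.stages[0]
        for stage in self.stages:
            if step >= stage.start_step:
                active = stage
        # Report only when the active stage changes.
        if self.verbose and active != self._active:
            print(f"step {step}: max pitch shift {active.max_pitch_shift}, "
                  f"max time shift {active.max_time_shift}")
        self._active = active
        return active.max_pitch_shift, active.max_time_shift
```

For example, `ToyCurriculum([Stage(0, 0, 0), Stage(1000, 2, 4), Stage(5000, 6, 12)])` returns `(0, 0)` at step 0, `(2, 4)` at step 1500, and `(6, 12)` from step 5000 onward. A common curriculum design starts with no augmentation and widens the range later. This page does not say whether `CurriculumSchedule` grows its ranges this way.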

Compute Loss On Batch


compute_batch_loss


def compute_batch_loss(
    batch, encoder, cfg, global_step, mae_decoder:NoneType=None, ema_encoder:NoneType=None, mep_model:NoneType=None,
    debug:bool=False
):

Compute the loss and return other auxiliary variables (for training or validation).
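The sketch below shows one plausible way to combine the losses. It assumes that each optional model (`mae_decoder`, `ema_encoder`, `mep_model`) contributes an auxiliary term only when passed in, and that `global_step` may ramp those terms in. The names `combine_losses` and `ramp`, the weighting scheme, and the warm-up are all hypothetical; none of them come from `compute_batch_loss`. The sketch uses plain floats so it runs without PyTorch, and the same arithmetic works on scalar tensors.

```python
from typing import Dict, Optional, Tuple


def ramp(step: int, warmup_steps: int) -> float:
    # Hypothetical linear warm-up from 0 to 1 over `warmup_steps` steps.
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, step / warmup_steps)


def combine_losses(
    main_loss: float,
    aux_losses: Dict[str, Optional[float]],
    weights: Dict[str, float],
    global_step: int,
    warmup_steps: int = 0,
) -> Tuple[float, Dict[str, float]]:
    """Return (total_loss, logs); an aux loss of None means that module is disabled."""
    scale = ramp(global_step, warmup_steps)
    total = main_loss
    logs = {"main": main_loss}
    for name, value in aux_losses.items():
        if value is None:  # e.g. the corresponding model was passed as None
            continue
        # Unlisted terms default to weight 1.0.
        total = total + scale * weights.get(name, 1.0) * value
        logs[name] = value
    logs["total"] = total
    return total, logs
```

For instance, `combine_losses(1.0, {"mae": 0.5, "mep": None}, {"mae": 0.5}, global_step=100)` returns a total of `1.25`, and `"mep"` is absent from the logs because that term was disabled.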

Main Training Loop


train


def train(
    cfg:DictConfig
):
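The page does not describe the loop body. The sketch below illustrates the usual structure: shuffle, forward pass, backward pass, optimizer step, periodic logging, and end-of-epoch validation. It fits a single scalar by SGD with a hand-written gradient so it runs on the standard library alone. `toy_train` and its config keys (`seed`, `epochs`, `batch_size`, `lr`, `log_every`) are hypothetical. The real `train` takes a `DictConfig` (presumably OmegaConf's), and a plain `dict` stands in for it here.

```python
import random
from typing import Dict, List, Tuple


def toy_train(cfg: Dict, train_data: List[float],
              val_data: List[float]) -> Tuple[float, List[float]]:
    """Hypothetical loop skeleton: fit a scalar w by SGD on mean-squared error."""
    rng = random.Random(cfg["seed"])
    data = list(train_data)  # copy so shuffling does not mutate the caller's list
    w = 0.0                  # the single "model parameter"
    global_step = 0
    val_history: List[float] = []

    for epoch in range(cfg["epochs"]):
        rng.shuffle(data)
        for i in range(0, len(data), cfg["batch_size"]):
            batch = data[i:i + cfg["batch_size"]]
            # Forward pass: MSE between the prediction w and each target.
            loss = sum((w - x) ** 2 for x in batch) / len(batch)
            # Backward pass: analytic gradient dL/dw.
            grad = sum(2.0 * (w - x) for x in batch) / len(batch)
            # Optimizer step (plain SGD).
            w -= cfg["lr"] * grad
            global_step += 1
            if global_step % cfg["log_every"] == 0:
                print(f"epoch {epoch} step {global_step} train_loss {loss:.4f}")
        # Validation at the end of each epoch, with no parameter updates.
        val_loss = sum((w - x) ** 2 for x in val_data) / len(val_data)
        val_history.append(val_loss)
    return w, val_history
```

A PyTorch version would replace the hand-written gradient and update with `loss.backward()` and `optimizer.step()`, calling `optimizer.zero_grad()` before each backward pass.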

CLI Entry Point