Work in progress for NASH Hackathon, Dec 17, 2021

This is like the 01_td_demo notebook, only we use a different dataset and generalize the dataloader a bit.

Installs and imports

%pip install -Uqq pip 

# Next line only executes on Colab. Colab users: Please enable GPU in Edit > Notebook settings
! [ -e /content ] && pip install -Uqq fastai git+https://github.com/drscotthawley/fastproaudio.git

# Additional installs for this tutorial
%pip install -q fastai_minima torchsummary pyzenodo3 wandb

# Install micro-tcn and auraloss packages (from source, will take a little while)
%pip install -q wheel --ignore-requires-python git+https://github.com/csteinmetz1/micro-tcn.git  git+https://github.com/csteinmetz1/auraloss

# After this cell finishes, restart the kernel and continue below
from fastai.vision.all import *
from fastai.text.all import *
from fastai.callback.fp16 import *
import wandb
from fastai.callback.wandb import *
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
from IPython.display import Audio 
import matplotlib.pyplot as plt
import torchsummary
from fastproaudio.core import *
from pathlib import Path
import glob
import json
import re 

import warnings
# mel-spectrogram plot keeps throwing matplotlib deprecation warnings
warnings.filterwarnings( "ignore", module = "librosa\..*" ) 
data_dir = '/home/shawley/datasets/timeAlignData_mono4_mc/'
#data_dir = '/home/shawley/datasets/pb_verb'
#data_dir = '/home/shawley/datasets/timealign_signaltrain_simple'
path = Path(data_dir) 

fnames_in = sorted(glob.glob(str(path)+'/*/input*'))
fnames_targ = sorted(glob.glob(str(path)+'/*/*targ*'))
ind = np.random.randint(len(fnames_in))   # pick one spot in the list of files
fnames_in[ind], fnames_targ[ind]
('/home/shawley/datasets/timeAlignData_mono4_mc/Val/input_107__1__60.wav',
 '/home/shawley/datasets/timeAlignData_mono4_mc/Val/target_107__1__60.wav')
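
Notice that the effect parameters are encoded in the filenames themselves: target_107__1__60.wav is file id 107 with parameter values (1, 60), which the dataset class further below recovers by splitting on the double underscore. A quick illustration of the convention (the real parsing lives in the dataset class):

fname = 'target_107__1__60.wav'
file_id = int(fname.split('_')[1])                            # -> 107
params = (float(fname.split('__')[1]),                        # -> 1.0
          float(fname.split('__')[2].replace('.wav','')))     # -> 60.0
print(file_id, params)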

Input audio

input, sample_rate = torchaudio.load(fnames_in[ind])
print("sample_rate = ",sample_rate)
show_audio(input, sample_rate)
sample_rate =  16000
Shape: (3, 320000), Dtype: torch.float32, Duration: 20.0 s
Max:  1.000,  Min: -1.000, Mean:  0.000, Std Dev:  0.106

Target output audio

target, sr_targ = torchaudio.load(fnames_targ[ind])
show_audio(target, sr_targ)
Shape: (3, 320000), Dtype: torch.float32, Duration: 20.0 s
Max:  1.000,  Min: -1.000, Mean:  0.000, Std Dev:  0.112

Dataset class and Dataloaders

Here we modify Christian's SignalTrainLA2ADataset class from micro-tcn. See his data.py

We'll use the original dataset class that Christian made, and then for fastai we'll "pack" the params and inputs together. (This will be loading multichannel WAV files, BTW.)
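
In case "packing" sounds mysterious: we just concatenate the handful of conditioning parameters onto the end of the audio tensor along the time axis, so fastai's (input, target) batch convention still holds, and then split them back apart inside the model's forward(). A minimal sketch of the round trip (shapes are illustrative only; n_params=2 as in this dataset):

import torch

n_params = 2
audio  = torch.randn(3, 16384)                       # (channels, time)
params = torch.tensor([[1.0, 0.6]])                  # (1, n_params)
params = torch.tile(params, (audio.shape[0], 1))     # repeat across channels so cat works
packed = torch.cat((audio, params), dim=-1)          # (3, 16384 + n_params)

# later, inside the model's forward():
x, p = packed[:, :-n_params], packed[:, -n_params:]
assert x.shape == (3, 16384) and p.shape == (3, n_params)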

Adding Positional Encoding

ConvNets and MLPs don't necessarily have a sense of position, but giving them one can improve performance on tasks where position matters -- such as time alignment. Various models encode position by adding extra channels; e.g., Transformers use sinusoidal (Fourier-series) encodings. We'll use a simplified scheme that we saw Francois Fleuret use:

def get_positional_input(this_seq_length, max_seq_length=200000, channel_index=0):
    """scheme taken from Francois Flueret's attentiontoy1.py, 
    cf. https://twitter.com/francoisfleuret/status/1263516788479922176"""
    seq_length = max_seq_length
    c = math.ceil(math.log(seq_length) / math.log(2.0))
    positional_input = ((torch.arange(seq_length).unsqueeze(0) // 2**torch.arange(c).unsqueeze(1))%2).float()
    positional_input = positional_input[:, 0:this_seq_length]  # clip it
    #print("this_seq_length, positional_input.shape =",this_seq_length, positional_input.shape)
    if channel_index==1: positional_input = positional_input.unsqueeze(0)
    return positional_input


def add_positional_encoding(input, channel_index=0):
    "adds channels onto the end of input"
    positional_input = get_positional_input(input.shape[-1], channel_index=channel_index)
    return torch.cat( (input, positional_input), dim=channel_index)
pe = get_positional_input(64)
print(f"{pe.shape[0]} PE channels for this sequence length.  (Could be more for longer sequences)")
fig, ax = plt.subplots(nrows=pe.shape[0], figsize=(12,10))
for c in range(pe.shape[0]):
    ax[c].plot(pe[c,:],'o-')
/tmp/ipykernel_1809707/2740461433.py:6: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  positional_input = ((torch.arange(seq_length).unsqueeze(0) // 2**torch.arange(c).unsqueeze(1))%2).float()
18 PE channels for this sequence length.  (Could be more for longer sequences)

...you get the idea
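
In plainer terms, channel c of this encoding is just the c-th bit of the sample index: a square wave that toggles every 2**c samples. Since max_seq_length=200000 and ceil(log2(200000)) = 18, that's where the 18 channels reported above come from. A quick optional sanity check (assuming get_positional_input as defined above):

import math
pe = get_positional_input(64)
t = torch.arange(64)
# channel 3 should toggle every 2**3 = 8 samples
assert torch.equal(pe[3], (torch.div(t, 2**3, rounding_mode='floor') % 2).float())
print(math.ceil(math.log2(200000)))   # -> 18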

USE_POSITIONAL_ENCODING = True 
REMOVE_CLICK = False

Dataset class

from microtcn.data import SignalTrainLA2ADataset

class SignalTrainLA2ADataset_fastai(SignalTrainLA2ADataset):
    "For fastai's sake, have getitem pack the inputs and params together"
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def __getitem__(self, idx):
        input, target, params = super().__getitem__(idx)
        return torch.cat((input,params),dim=-1), target   # pack input and params together  
# actually we're going to modify Christian's code some so let's rename it...
class TimeAlignDataset(torch.utils.data.Dataset):
    """ SignalTrain LA2A dataset. Source: [10.5281/zenodo.3824876](https://zenodo.org/record/3824876)."""
    def __init__(self, root_dir, subset="train", length=16384, preload=False, half=True, 
                 fraction=1.0, use_soundfile=False, positional_encoding=False):
        """
        Args:
            root_dir (str): Path to the root directory of the SignalTrain dataset.
            subset (str, optional): Pull data either from "train", "val", "test", or "full" subsets. (Default: "train")
            length (int, optional): Number of samples in the returned examples. (Default: 16384)
            preload (bool, optional): Read in all data into RAM during init. (Default: False)
            half (bool, optional): Store the float32 audio as float16. (Default: True)
            fraction (float, optional): Fraction of the data to load from the subset. (Default: 1.0)
            use_soundfile (bool, optional): Use the soundfile library to load instead of torchaudio. (Default: False)
            positional_encoding (bool, optional): Append positional-encoding channels to each input example. (Default: False)
        """
        self.root_dir = root_dir
        self.subset = subset
        self.length = length
        self.preload = preload
        self.half = half
        self.fraction = fraction
        self.use_soundfile = use_soundfile
        self.positional_channels, self.positional_input  = 0, None
        if positional_encoding:
            self.positional_input = get_positional_input(length)  # same PE tensor for all time
            self.positional_channels = self.positional_input.shape[0]  # number of PE channels
            #print("self.positional_input.shape = ",self.positional_input.shape)

        if self.subset == "full":
            self.target_files = glob.glob(os.path.join(self.root_dir, "**", "target_*.wav"))
            self.input_files  = glob.glob(os.path.join(self.root_dir, "**", "input_*.wav"))
        else:
            # get all the target files in the directory first
            self.target_files = glob.glob(os.path.join(self.root_dir, self.subset.capitalize(), "target_*.wav"))
            self.input_files  = glob.glob(os.path.join(self.root_dir, self.subset.capitalize(), "input_*.wav"))

        self.examples = [] 
        self.minutes = 0  # total number of minutes in the subset

        # ensure that the sets are ordered correctly
        self.target_files.sort()
        self.input_files.sort()

        # get the parameters
        #param_parse_fns = [x.replace("__pb_dist","") for x in self.target_files]
        self.params = [(float(f.split("__")[1].replace(".wav","")), float(f.split("__")[2].replace(".wav",""))) for f in self.target_files]
        #print("self.params = ",self.params)
        
        # loop over files to count total length
        for idx, (tfile, ifile, params) in enumerate(zip(self.target_files, self.input_files, self.params)):

            ifile_id = int(os.path.basename(ifile).split("_")[1])
            tfile_id = int(os.path.basename(tfile).split("_")[1])
            if ifile_id != tfile_id:
                raise RuntimeError(f"Found non-matching file ids: {ifile_id} != {tfile_id}! Check dataset.")

            md = torchaudio.info(tfile)
            num_frames = md.num_frames

            if self.preload:
                sys.stdout.write(f"* Pre-loading... {idx+1:3d}/{len(self.target_files):3d} ...\r")
                sys.stdout.flush()
                input, sr  = self.load(ifile)
                target, sr = self.load(tfile)

                num_frames = int(np.min([input.shape[-1], target.shape[-1]]))
                if input.shape[-1] != target.shape[-1]:
                    print(os.path.basename(ifile), input.shape[-1], os.path.basename(tfile), target.shape[-1])
                    raise RuntimeError("Found potentially corrupt file!")
                    
                if self.positional_input is not None: input = torch.cat((input, self.positional_input), dim=1)

                if self.half:
                    input = input.half()
                    target = target.half()
            else:
                input = None
                target = None

            # create one entry for each patch
            self.file_examples = []
            nmax = (num_frames // self.length)
            assert nmax > 0, f"num_frames = {num_frames} but self.length = {self.length}"
            for n in range(nmax):
                offset = int(n * self.length)
                end = offset + self.length
                #print("idx, params = ",idx,params)
                self.file_examples.append({"idx": idx, 
                                           "target_file" : tfile,
                                           "input_file" : ifile,
                                           "input_audio" : input[:,offset:end] if input is not None else None,
                                           "target_audio" : target[:,offset:end] if input is not None else None,
                                           "params" : params,
                                           "offset": offset,
                                           "frames" : num_frames})

            # add to overall file examples
            self.examples += self.file_examples
        
        # use only a fraction of the subset data if applicable
        if self.subset == "train":
            classes = set([ex['params'] for ex in self.examples])
            print("classes = ",classes)
            n_classes = len(classes) # number of unique compressor configurations
            n_classes = 1 if n_classes==0 else n_classes  # SHH don't care; sick of these errors 
            fraction_examples = int(len(self.examples) * self.fraction)
            n_examples_per_class = int(fraction_examples / n_classes)
            n_min_total = ((self.length * n_examples_per_class * n_classes) / md.sample_rate) / 60 
            n_min_per_class = ((self.length * n_examples_per_class) / md.sample_rate) / 60 
            print(sorted(classes))
            print(f"Total Examples: {len(self.examples)}     Total classes: {n_classes}")
            print(f"Fraction examples: {fraction_examples}    Examples/class: {n_examples_per_class}")
            print(f"Training with {n_min_per_class:0.2f} min per class    Total of {n_min_total:0.2f} min")

            if n_examples_per_class <= 0: 
                raise ValueError(f"Fraction `{self.fraction}` set too low. No examples selected.")

            sampled_examples = []

            for config_class in classes: # select N examples from each class
                class_examples = [ex for ex in self.examples if ex["params"] == config_class]
                example_indices = np.random.randint(0, high=len(class_examples), size=n_examples_per_class)
                class_examples = [class_examples[idx] for idx in example_indices]
                extra_factor = int(1/self.fraction)
                sampled_examples += class_examples * extra_factor

            self.examples = sampled_examples

        self.minutes = ((self.length * len(self.examples)) / md.sample_rate) / 60 

        # we then want to get the input files
        print(f"Located {len(self.examples)} examples totaling {self.minutes:0.2f} min in the {self.subset} subset.")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        if self.preload:
            audio_idx = self.examples[idx]["idx"]
            offset = self.examples[idx]["offset"]
            input = self.examples[idx]["input_audio"]
            target = self.examples[idx]["target_audio"]
        else:
            offset = self.examples[idx]["offset"] 
            input, sr  = torchaudio.load(self.examples[idx]["input_file"], 
                                        num_frames=self.length, 
                                        frame_offset=offset, 
                                        normalize=False)
            #print("input.shape, self.positional_input.shape =",input.shape, self.positional_input.shape)
            if self.positional_input is not None: input = torch.cat((input, self.positional_input), dim=0)

            
            target, sr = torchaudio.load(self.examples[idx]["target_file"], 
                                        num_frames=self.length, 
                                        frame_offset=offset, 
                                        normalize=False)
            # REMOVE_CLICK is a bool used as a slice index: True (=1) drops the first (click) channel
            input = input[REMOVE_CLICK:,:]
            target = target[REMOVE_CLICK:,:]
            
            if self.half:
                input = input.half()
                target = target.half()

        # at random with p=0.5, flip the phase (but don't flip the click channel or the PE channels)
        if np.random.rand() > 0.5:
            end_ch = -self.positional_channels if self.positional_channels > 0 else None
            input[1:end_ch, :] = -input[1:end_ch, :]
            target[1:, :] = -target[1:, :]   # target has no PE channels appended

        # then get the tuple of parameters
        params = torch.tensor(self.examples[idx]["params"]).unsqueeze(0)
        params[:,1] /= 100

        #print(f"Checking: idx = {idx}, input.shape = {input.shape}, target.shape = {target.shape}, params.shape = {params.shape}")
        
        return input, target, params

    def load(self, filename):
        if self.use_soundfile:
            x, sr = sf.read(filename, always_2d=True)
            x = torch.tensor(x.T)
        else:
            x, sr = torchaudio.load(filename, normalize=True)  # normalize=True for the pedalboard-rendered audio (originally False)
        return x, sr
    
    
    
class TimeAlignDataset_fastai(TimeAlignDataset):
    "For fastai's sake, have getitem pack the inputs and params together"
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def __getitem__(self, idx):
        input, target, params = super().__getitem__(idx)
        #print(f"Checking: idx = {idx}, input.shape = {input.shape}, params.shape = {params.shape}")
        if input.shape[0] > params.shape[0]:  # this is an artifact of our trying to pack things together
            #print("trying to fix...")
            params = torch.tile(params, (input.shape[0],1))
            #print(f"Checking2: idx = {idx}, input.shape = {input.shape}, params.shape = {params.shape}\n")
        return torch.cat((input,params),dim=-1), target   # pack input and params together
SAMPLE_RATE = sample_rate 
class Args(object):  # stand-in for parseargs. these are all micro-tcn defaults
    model_type ='tcn'
    root_dir = str(path)
    preload = False
    sample_rate = SAMPLE_RATE
    shuffle = True
    train_subset = 'train'
    val_subset = 'val'
    train_length = min(int(65536), input.shape[-1])
    train_fraction = 1.0
    eval_length = train_length*2 # 65536  # keep these the same if using positional encoding
    batch_size = 8   # original is 32, my laptop needs smaller, esp. w/o half precision
    num_workers = 4  # 1 for debugging, 4 for normal usage 
    precision = 32  # LEAVE AS 32 FOR NOW: HALF PRECISION (16) NOT WORKING YET -SHH
    n_params = 2
    
args = Args()

# just re-measure input and target sizes in case something changed in the REPL/Jupyter notebook state
input, sample_rate = torchaudio.load(fnames_in[ind])
target, sr_targ = torchaudio.load(fnames_targ[ind])

USER_INPUT_CHANNELS = input.shape[0] - REMOVE_CLICK      # how many were supplied by the user, how many we'll plot
TOTAL_INPUT_CHANNELS = USER_INPUT_CHANNELS
TARGET_OUTPUT_CHANNELS = target.shape[0] - REMOVE_CLICK

if USE_POSITIONAL_ENCODING:
    pe = get_positional_input(args.train_length)
    print("pe.shape =",pe.shape)
    TOTAL_INPUT_CHANNELS = USER_INPUT_CHANNELS + pe.shape[0]    # how many the model will take
    
print(f"USER_INPUT_CHANNELS = {USER_INPUT_CHANNELS}") 
print(f"TOTAL_INPUT_CHANNELS = {TOTAL_INPUT_CHANNELS}")
print(f"TARGET_OUTPUT_CHANNELS = {TARGET_OUTPUT_CHANNELS}")

#if args.precision == 16:  torch.set_default_dtype(torch.float16)

# setup the dataloaders
train_dataset = TimeAlignDataset_fastai(args.root_dir, 
                    subset=args.train_subset, 
                    fraction=args.train_fraction,
                    half=True if args.precision == 16 else False, 
                    preload=args.preload, 
                    length=args.train_length,
                    positional_encoding=USE_POSITIONAL_ENCODING)

train_dataloader = torch.utils.data.DataLoader(train_dataset, 
                    shuffle=args.shuffle,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)

val_dataset = TimeAlignDataset_fastai(args.root_dir, 
                    preload=args.preload,
                    half=True if args.precision == 16 else False,
                    subset=args.val_subset,
                    length=args.eval_length,
                    positional_encoding=USE_POSITIONAL_ENCODING)

val_dataloader = torch.utils.data.DataLoader(val_dataset, 
                    shuffle=False,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)
pe.shape = torch.Size([18, 65536])
USER_INPUT_CHANNELS = 3
TOTAL_INPUT_CHANNELS = 21
TARGET_OUTPUT_CHANNELS = 3
classes =  {(1.0, 60.0)}
[(1.0, 60.0)]
Total Examples: 416     Total classes: 1
Fraction examples: 416    Examples/class: 416
Training with 28.40 min per class    Total of 28.40 min
Located 416 examples totaling 28.40 min in the train subset.
Located 60 examples totaling 8.19 min in the val subset.
/tmp/ipykernel_1809707/2740461433.py:6: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  positional_input = ((torch.arange(seq_length).unsqueeze(0) // 2**torch.arange(c).unsqueeze(1))%2).float()

If the user requested fp16 precision then we need to install NVIDIA apex:

if False and args.precision == 16:
    %pip install -q --disable-pip-version-check --no-cache-dir git+https://github.com/NVIDIA/apex
    from apex.fp16_utils import convert_network

Define the model(s)

Christian defined a lot of models. We'll use the TCN-300; the LSTM depends on a lot of PyTorch Lightning machinery, so we'll skip it.

#from microtcn.lstm import LSTMModel # actually the LSTM depends on a lot of Lightning stuff, so we'll skip that
from microtcn.utils import center_crop, causal_crop

# this is all exactly Christian's code except one tiny change in "groups=" for self.res in TCNBlock. 

class FiLM(torch.nn.Module):
    def __init__(self, 
                 num_features, 
                 cond_dim):
        super(FiLM, self).__init__()
        self.num_features = num_features
        self.bn = torch.nn.BatchNorm1d(num_features, affine=False)
        self.adaptor = torch.nn.Linear(cond_dim, num_features * 2)

    def forward(self, x, cond):

        cond = self.adaptor(cond)
        g, b = torch.chunk(cond, 2, dim=-1)
        g = g.permute(0,2,1)
        b = b.permute(0,2,1)

        x = self.bn(x)      # apply BatchNorm without affine
        x = (x * g) + b     # then apply conditional affine

        return x
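
# FiLM shape walk-through (illustrative; B = batch size, T = time steps, num_features = out_ch of the block):
#   x:    (B, num_features, T)        cond: (B, 1, cond_dim)
#   adaptor(cond) -> (B, 1, 2*num_features) -> chunk -> g, b each (B, 1, num_features)
#   permute(0,2,1) -> g, b: (B, num_features, 1), so the per-channel scale and shift
#   broadcast across time:  out = bn(x) * g + b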

class TCNBlock(torch.nn.Module):
    def __init__(self, 
                in_ch, 
                out_ch, 
                kernel_size=3, 
                padding=0, 
                dilation=1, 
                grouped=False, 
                conditional=False, 
                **kwargs):
        super(TCNBlock, self).__init__()

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.padding = padding
        self.dilation = dilation
        self.grouped = grouped
        self.conditional = conditional

        groups = out_ch if grouped and (in_ch % out_ch == 0) else 1
        
        self.conv1 = torch.nn.Conv1d(in_ch, 
                                     out_ch, 
                                     kernel_size=kernel_size, 
                                     padding=padding, 
                                     dilation=dilation,
                                     groups=groups,
                                     bias=False)
        #if grouped:
        #    self.conv1b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1)

        if conditional:
            self.film = FiLM(out_ch, 32)
        else:
            self.bn = torch.nn.BatchNorm1d(out_ch)

        self.relu = torch.nn.PReLU(out_ch)
        print("self.res params: ",in_ch,  out_ch)
        self.res = torch.nn.Conv1d(in_ch, 
                                   out_ch, 
                                   kernel_size=1,
                                   groups=groups, # SHH: this is a change; Christian's original read =in_ch here. 
                                   bias=False)

    def forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        x_in = x        
        x = self.conv1(x)
        #if self.grouped: # apply pointwise conv
        #    x = self.conv1b(x)
        if p is not None:   # apply FiLM conditioning
            x = self.film(x, p)
        else:
            x = self.bn(x)
        x = self.relu(x)

        x_res = self.res(x_in)
        x = x + center_crop(x_res, x.size(-1))

        return x

class TCNModel(torch.nn.Module):
    """ Temporal convolutional network with conditioning module.

        Args:
            nparams (int): Number of conditioning parameters.
            ninputs (int): Number of input channels (mono = 1, stereo = 2). Default: 1
            noutputs (int): Number of output channels (mono = 1, stereo = 2). Default: 1
            nblocks (int): Number of total TCN blocks. Default: 10
            kernel_size (int): Width of the convolutional kernels. Default: 3
            dilation_growth (int): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1
            channel_growth (int): Compute the output channels at each block as in_ch * channel_growth. Default: 1
            channel_width (int): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 32
            stack_size (int): Number of blocks that constitute a single stack of blocks. Default: 10
            grouped (bool): Use grouped convolutions to reduce the total number of parameters. Default: False
            num_examples (int): Number of evaluation audio examples to log after each epoch. Default: 4
        """
    def __init__(self, 
                 nparams,
                 ninputs=1,
                 noutputs=1,
                 nblocks=10, 
                 kernel_size=3, 
                 dilation_growth=1, 
                 channel_growth=1, 
                 channel_width=32, 
                 stack_size=10,
                 grouped=False,
                 num_examples=4,
                 save_dir=None,
                 **kwargs):
        super(TCNModel, self).__init__()

        self.nparams=nparams
        self.ninputs=ninputs
        self.noutputs=noutputs
        self.nblocks=nblocks
        self.kernel_size=kernel_size
        self.dilation_growth=dilation_growth
        self.channel_growth=channel_growth
        self.channel_width=channel_width
        self.stack_size=stack_size
        self.grouped=grouped
        self.num_examples=num_examples
        self.save_dir=save_dir

        # setup loss functions
        self.l1      = torch.nn.L1Loss()

        print("nparams = ",nparams)
        if self.nparams > 0:
            self.gen = torch.nn.Sequential(
                torch.nn.Linear(nparams, 16),
                torch.nn.ReLU(),
                torch.nn.Linear(16, 32),
                torch.nn.ReLU(),
                torch.nn.Linear(32, 32),
                torch.nn.ReLU()
            )

        self.blocks = torch.nn.ModuleList()
        for n in range(nblocks):
            in_ch = out_ch if n > 0 else ninputs
            
            if self.channel_growth > 1:
                out_ch = in_ch * self.channel_growth 
            else:
                out_ch = self.channel_width

            dilation = self.dilation_growth ** (n % self.stack_size)
            #dilation = dilation_growth
            self.blocks.append(TCNBlock(in_ch, 
                                        out_ch, 
                                        kernel_size=self.kernel_size, 
                                        dilation=dilation,
                                        grouped=self.grouped,
                                        conditional=True if self.nparams > 0 else False))

        self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1)

    def forward(self, x, p=None):
        # if parameters present, 
        # compute global conditioning
        #print("TCNModel.forward: x.shape = ",x.shape)
        if p is not None:
            cond = self.gen(p)
        else:
            cond = None

        # iterate over blocks passing conditioning
        for idx, block in enumerate(self.blocks):
            x = block(x, cond)
            if idx == 0:
                skips = x
            else:
                skips = center_crop(skips, x.size(-1))
                skips = skips + x

        return torch.tanh(self.output(x + skips))

    def compute_receptive_field(self):
        """ Compute the receptive field in samples."""
        rf = self.kernel_size
        for n in range(1,self.nblocks):
            dilation = self.dilation_growth ** (n % self.stack_size)
            rf = rf + ((self.kernel_size-1) * dilation)
        return rf
class TCNModel_fastai(TCNModel):
    "For fastai's sake, unpack the inputs and params"
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def forward(self, x, p=None):
        if (p is None) and (self.nparams > 0):  # unpack the params if needed
            assert len(list(x.size())) == 3   # sanity check 
            x, p = x[:,:,0:-self.nparams], x[:,:,-self.nparams:]
            if p.shape[1] > 1:
                p = p[:,0:1,:]  # only need one copy of params, not the stacked copies supplied by DL. 
        return super().forward(x, p=p)
# micro-tcn defines several different model configurations. I just chose one of them. 
train_configs = [
      {"name" : "custom",   # tried messing around with nblocks, dilation_growth and kernel_size, to not much effect
     "model_type" : "tcn",
     "nblocks" : 8,
     "dilation_growth" : 3,
     "kernel_size" : 15,
     "causal" : False,
     "train_fraction" : 1.00,
     "batch_size" : args.batch_size
    }
]

print(f"USER_INPUT_CHANNELS = {USER_INPUT_CHANNELS}") 
print(f"TOTAL_INPUT_CHANNELS = {TOTAL_INPUT_CHANNELS}")
print(f"TARGET_OUTPUT_CHANNELS = {TARGET_OUTPUT_CHANNELS}")

dict_args = train_configs[0]
dict_args["channel_width"] = 32
dict_args["nparams"] = 2
dict_args["ninputs"] = TOTAL_INPUT_CHANNELS  # number of input channels
dict_args["noutputs"] = TARGET_OUTPUT_CHANNELS  # number of output channels
dict_args["grouped"] = False 

model = TCNModel_fastai(**dict_args)
dtype = torch.float32

rf = model.compute_receptive_field()
print("Receptive field (in samples) = ",rf)
USER_INPUT_CHANNELS = 3
TOTAL_INPUT_CHANNELS = 21
TARGET_OUTPUT_CHANNELS = 3
nparams =  2
self.res params:  21 32
self.res params:  32 32
self.res params:  32 32
self.res params:  32 32
self.res params:  32 32
self.res params:  32 32
self.res params:  32 32
self.res params:  32 32
Receptive field (in samples) =  45921
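
As a sanity check on that number: with kernel_size=15, dilation_growth=3, nblocks=8 and stack_size=10, the dilations are 3^0 through 3^7, so rf = 15 + 14*(3 + 9 + 27 + 81 + 243 + 729 + 2187) = 15 + 14*3279 = 45921 samples, or about 2.9 s at 16 kHz.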

Let's take a look at the model:

# this summary allows one to compare the original TCNModel with the TCNModel_fastai
if type(model) == TCNModel_fastai:
    torchsummary.summary(model, [(TOTAL_INPUT_CHANNELS,args.train_length)], device="cpu")
else:
    torchsummary.summary(model, [(TOTAL_INPUT_CHANNELS,args.train_length),(1,2)], device="cpu")
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                [-1, 1, 16]              48
              ReLU-2                [-1, 1, 16]               0
            Linear-3                [-1, 1, 32]             544
              ReLU-4                [-1, 1, 32]               0
            Linear-5                [-1, 1, 32]           1,056
              ReLU-6                [-1, 1, 32]               0
            Conv1d-7            [-1, 32, 65520]          10,080
            Linear-8                [-1, 1, 64]           2,112
       BatchNorm1d-9            [-1, 32, 65520]               0
             FiLM-10            [-1, 32, 65520]               0
            PReLU-11            [-1, 32, 65520]              32
           Conv1d-12            [-1, 32, 65534]             672
         TCNBlock-13            [-1, 32, 65520]               0
           Conv1d-14            [-1, 32, 65478]          15,360
           Linear-15                [-1, 1, 64]           2,112
      BatchNorm1d-16            [-1, 32, 65478]               0
             FiLM-17            [-1, 32, 65478]               0
            PReLU-18            [-1, 32, 65478]              32
           Conv1d-19            [-1, 32, 65520]           1,024
         TCNBlock-20            [-1, 32, 65478]               0
           Conv1d-21            [-1, 32, 65352]          15,360
           Linear-22                [-1, 1, 64]           2,112
      BatchNorm1d-23            [-1, 32, 65352]               0
             FiLM-24            [-1, 32, 65352]               0
            PReLU-25            [-1, 32, 65352]              32
           Conv1d-26            [-1, 32, 65478]           1,024
         TCNBlock-27            [-1, 32, 65352]               0
           Conv1d-28            [-1, 32, 64974]          15,360
           Linear-29                [-1, 1, 64]           2,112
      BatchNorm1d-30            [-1, 32, 64974]               0
             FiLM-31            [-1, 32, 64974]               0
            PReLU-32            [-1, 32, 64974]              32
           Conv1d-33            [-1, 32, 65352]           1,024
         TCNBlock-34            [-1, 32, 64974]               0
           Conv1d-35            [-1, 32, 63840]          15,360
           Linear-36                [-1, 1, 64]           2,112
      BatchNorm1d-37            [-1, 32, 63840]               0
             FiLM-38            [-1, 32, 63840]               0
            PReLU-39            [-1, 32, 63840]              32
           Conv1d-40            [-1, 32, 64974]           1,024
         TCNBlock-41            [-1, 32, 63840]               0
           Conv1d-42            [-1, 32, 60438]          15,360
           Linear-43                [-1, 1, 64]           2,112
      BatchNorm1d-44            [-1, 32, 60438]               0
             FiLM-45            [-1, 32, 60438]               0
            PReLU-46            [-1, 32, 60438]              32
           Conv1d-47            [-1, 32, 63840]           1,024
         TCNBlock-48            [-1, 32, 60438]               0
           Conv1d-49            [-1, 32, 50232]          15,360
           Linear-50                [-1, 1, 64]           2,112
      BatchNorm1d-51            [-1, 32, 50232]               0
             FiLM-52            [-1, 32, 50232]               0
            PReLU-53            [-1, 32, 50232]              32
           Conv1d-54            [-1, 32, 60438]           1,024
         TCNBlock-55            [-1, 32, 50232]               0
           Conv1d-56            [-1, 32, 19614]          15,360
           Linear-57                [-1, 1, 64]           2,112
      BatchNorm1d-58            [-1, 32, 19614]               0
             FiLM-59            [-1, 32, 19614]               0
            PReLU-60            [-1, 32, 19614]              32
           Conv1d-61            [-1, 32, 50232]           1,024
         TCNBlock-62            [-1, 32, 19614]               0
           Conv1d-63             [-1, 3, 19614]              99
================================================================
Total params: 144,339
Trainable params: 144,339
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 5.25
Forward/backward pass size (MB): 678.83
Params size (MB): 0.55
Estimated Total Size (MB): 684.63
----------------------------------------------------------------

Getting the model into fastai form

Zach Mueller made a very helpful fastai_minima package that we'll use, following his instructions.

TODO: Zach says I should either use fastai or fastai_minima, not mix them like I'm about to do. But what I have below is the only thing that works right now. ;-)

# I guess we could've imported these up at the top of the notebook...
from torch import optim
from fastai_minima.optimizer import OptimWrapper
#from fastai_minima.learner import Learner  # this doesn't include lr_find()
from fastai.learner import Learner
from fastai_minima.learner import DataLoaders
#from fastai_minima.callback.training_utils import CudaCallback, ProgressCallback # not sure if I need these
def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, **kwargs))

dls = DataLoaders(train_dataloader, val_dataloader)

Checking Dataloaders

if args.precision==16: 
    dtype = torch.float16
    model = convert_network(model, torch.float16)

model = model.to('cuda:0')
if type(model) == TCNModel_fastai:
    print("We're using Hawley's modified code")
    packed, targ = dls.one_batch()
    print("After dls: packed.shape, targ.shape =",packed.shape, targ.shape)
    inp, params = packed[:,:,0:-dict_args['nparams']], packed[:,:,-dict_args['nparams']:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))
else:
    print("We're using Christian's version of Dataloader and model")
    inp, targ, params = dls.one_batch()
    pred = model.forward(inp.to('cuda:0',dtype=dtype), p=params.to('cuda:0', dtype=dtype))
print(f"input  = {inp.size()}\ntarget = {targ.size()}\nparams = {params.size()}\npred   = {pred.size()}")

print(f"So you're only going to get {pred.shape[-1]} predictions.")
We're using Hawley's modified code
After dls: packed.shape, targ.shape = torch.Size([8, 21, 65538]) torch.Size([8, 3, 65536])
input  = torch.Size([8, 21, 65536])
target = torch.Size([8, 3, 65536])
params = torch.Size([8, 21, 2])
pred   = torch.Size([8, 3, 19616])
So you're only going to get 19616 predictions.
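
(That 19616 is no mystery: the network uses no padding, so the output length is the input length minus the receptive field plus one, i.e. 65536 - 45921 + 1 = 19616.)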

We can make the pred and target the same length by cropping when we compute the loss:

import auraloss

class Crop_Loss:
    "Crop target size to match preds"
    def __init__(self, axis=-1, causal=False, reduction="mean", train_loss="l1+stft"):
        store_attr()
        self.train_loss = train_loss
        self.l1      = torch.nn.L1Loss()
        self.stft    = auraloss.freq.STFTLoss()

    def __call__(self, pred, targ):
        targ = causal_crop(targ, pred.shape[-1]) if self.causal else center_crop(targ, pred.shape[-1])
        #pred, targ = TensorBase(pred), TensorBase(targ)
        assert pred.shape == targ.shape, f'pred.shape = {pred.shape} but targ.shape = {targ.shape}'
        #return self.loss_func(pred,targ).flatten().mean() if self.reduction == "mean" else loss(pred,targ).flatten().sum()
        l1_loss = self.l1(pred, targ)
        stft_loss = self.stft(pred, targ)
        if self.train_loss=="l1+stft":
            loss = l1_loss + stft_loss
        elif self.train_loss=="l1":
            loss = l1_loss
        return loss.flatten().mean() if self.reduction == "mean" else loss.flatten().sum()


# we could add a metric like MSE if we want
def crop_mse(pred, targ, causal=False): 
    targ = causal_crop(targ, pred.shape[-1]) if causal else center_crop(targ, pred.shape[-1])
    return ((pred - targ)**2).mean()
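
Just as an optional quick check of the loss (not part of the original flow), we can feed it the pred and targ batch from the "Checking Dataloaders" cell above, moving pred back to the CPU first:

loss_fn = Crop_Loss(causal=dict_args['causal'])
quick_loss = loss_fn(pred.detach().cpu().float(), targ.float())  # crops targ from 65536 down to 19616 samples
print(quick_loss)   # single scalar: L1 + STFT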

Enable logging with WandB:

wandb.login()
wandb: Currently logged in as: drscotthawley (use `wandb login --relogin` to force relogin)
True

Define the fastai Learner and callbacks

We're going to add a new custom WandBAudio callback further below, which we'll use when we call fit().

WandBAudio Callback

In order to log audio samples, let's write our own audio-logging callback for fastai:

class WandBAudio(Callback):
    """Progress-like callback: log audio to WandB"""
    order = ProgressCallback.order+1
    def __init__(self, n_preds=5, sample_rate=44100):
        store_attr()

    def after_epoch(self):  
        if not self.learn.training:
            with torch.no_grad():
                preds, targs = [x.detach().cpu().numpy().copy() for x in [self.learn.pred, self.learn.y]]
            log_dict = {}
            for i in range(min(self.n_preds, preds.shape[0])): # note wandb only supports mono
                    log_dict[f"preds_{i}"] = wandb.Audio(preds[i,0,:], caption=f"preds_{i}", sample_rate=self.sample_rate)
            wandb.log(log_dict)

Learner and wandb init

wandb.init(project='time-align')   # no run name given; could pass e.g. name=json.dumps(dict_args)

learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func,
               cbs= [WandbCallback()])
wandb: wandb version 0.12.9 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
Tracking run with wandb version 0.12.2
Syncing run peach-vortex-57 to Weights & Biases (Documentation).
Project page: https://wandb.ai/drscotthawley/time-align
Run page: https://wandb.ai/drscotthawley/time-align/runs/2rkico9o
Run data is saved locally in /home/shawley/fastproaudio/wandb/run-20211219_235350-2rkico9o

Train the model

We can use the fastai learning rate finder to suggest a learning rate:

learn.lr_find(start_lr=1e-6, end_lr=1) 
SuggestedLRs(valley=0.001737800776027143)

And now we'll train using the one-cycle LR schedule, with the WandBAudio callback. (Ignore any warning messages)

epochs = 200 
learn.fit_one_cycle(epochs, lr_max=1e-3, cbs=WandBAudio(sample_rate=args.sample_rate))
# ignore WandbCallback warnings that follow 
Could not gather input dimensions
WandbCallback requires use of "SaveModelCallback" to log best model
WandbCallback was not able to prepare a DataLoader for logging prediction samples -> 
epoch train_loss valid_loss crop_mse time
0 7.481353 6.582051 0.024052 00:06
1 6.594101 6.459403 0.020407 00:06
2 6.081297 6.055995 0.017898 00:06
3 5.702849 5.694269 0.016372 00:05
4 5.367577 5.313276 0.014825 00:06
5 5.041162 4.947532 0.014008 00:05
6 4.688604 4.497010 0.013257 00:06
7 4.231922 4.069275 0.012758 00:05
8 3.947678 3.979271 0.012501 00:06
9 3.904077 3.957376 0.012426 00:06
10 3.809494 3.845006 0.012258 00:06
11 3.835581 3.889054 0.012233 00:06
12 3.731839 4.047597 0.012640 00:06
13 3.693704 3.663044 0.012239 00:06
14 3.671490 3.652269 0.012241 00:06
15 3.632176 3.706145 0.012425 00:06
16 3.635157 3.710897 0.012378 00:06
17 3.606327 3.720078 0.012662 00:06
18 3.579929 3.692073 0.012752 00:06
19 3.553549 3.635737 0.012668 00:06
20 3.557667 3.628015 0.012941 00:06
21 3.550590 3.565692 0.012986 00:06
22 3.517530 3.662328 0.013708 00:06
23 3.478668 3.546507 0.013864 00:06
24 3.424166 3.437811 0.013449 00:06
25 3.397403 3.550436 0.013586 00:06
26 3.359569 3.512471 0.013645 00:06
27 3.317761 3.313176 0.013732 00:06
28 3.305267 3.540271 0.014228 00:06
29 3.274491 3.462339 0.014104 00:06
30 3.246510 3.314009 0.013738 00:05
31 3.209535 3.371364 0.013971 00:06
32 3.234617 3.304230 0.014217 00:06
33 3.195307 3.288021 0.014145 00:06
34 3.175634 3.480893 0.014560 00:06
35 3.210269 3.413791 0.014392 00:05
36 3.195909 3.337735 0.014442 00:06
37 3.175114 3.292834 0.014443 00:06
38 3.157440 3.521394 0.014482 00:06
39 3.190914 3.189670 0.014276 00:06
40 3.123457 3.173429 0.014537 00:06
41 3.166865 3.396136 0.014491 00:06
42 3.151710 3.628282 0.015125 00:06
43 3.133917 3.357393 0.014332 00:06
44 3.233766 3.337438 0.014380 00:06
45 3.198089 3.274935 0.014499 00:06
46 3.156824 3.377739 0.015501 00:06
47 3.137549 3.362537 0.014300 00:06
48 3.148103 3.221290 0.014119 00:06
49 3.102581 3.177467 0.014275 00:06
50 3.142584 3.258965 0.014269 00:06
51 3.143037 3.164217 0.014826 00:06
52 3.121495 3.405983 0.014679 00:06
53 3.218599 3.475188 0.014659 00:06
54 3.226979 3.637296 0.014768 00:06
55 3.167770 3.277506 0.014167 00:06
56 3.116071 3.195097 0.014356 00:06
57 3.084958 3.256184 0.014214 00:06
58 3.071096 3.451308 0.014772 00:06
59 3.132871 3.466042 0.014806 00:06
60 3.147635 3.434664 0.014164 00:06
61 3.086549 3.245490 0.014177 00:06
62 3.132288 3.146089 0.014152 00:06
63 3.109426 3.277438 0.014332 00:06
64 3.115413 3.451010 0.014142 00:06
65 3.142497 3.181455 0.014156 00:06
66 3.118155 3.166035 0.013893 00:06
67 3.101562 3.397317 0.013778 00:06
68 3.031350 3.171115 0.014056 00:06
69 3.045635 3.454994 0.013876 00:06
70 3.156406 3.152696 0.013762 00:06
71 3.101704 3.098252 0.014133 00:06
72 3.170986 3.225254 0.013754 00:06
73 3.140736 3.462520 0.013836 00:06
74 3.077533 3.072246 0.013575 00:06
75 3.044579 3.210176 0.013591 00:06
76 3.052052 3.240107 0.013700 00:06
77 3.050879 3.223253 0.013571 00:06
78 3.045250 3.143569 0.013726 00:06
79 3.021276 3.016362 0.013569 00:06
80 2.956156 3.084649 0.013544 00:06
81 3.076442 3.297247 0.013462 00:06
82 3.060890 3.303410 0.013412 00:06
83 2.990256 3.111420 0.013585 00:06
84 2.997935 3.003673 0.013460 00:06
85 2.954474 3.037837 0.013611 00:06
86 3.090324 3.300683 0.013501 00:06
87 3.065035 3.079030 0.013443 00:06
88 2.996978 3.138466 0.013788 00:06
89 2.970476 3.157427 0.013790 00:06
90 3.023928 2.914707 0.013689 00:06
91 2.986216 3.157763 0.013571 00:06
92 3.042987 3.144066 0.013543 00:06
93 3.048525 3.302791 0.013878 00:06
94 3.059416 3.204684 0.013530 00:06
95 2.986650 3.355709 0.013364 00:06
96 3.203569 3.643076 0.013625 00:06
97 3.126135 3.230669 0.013452 00:06
98 3.048590 3.026189 0.014924 00:06
99 3.183684 3.273628 0.013990 00:05
100 2.998711 2.893994 0.013407 00:06
101 2.957025 3.028701 0.013493 00:06
102 2.912347 2.983477 0.013414 00:06
103 2.906831 2.818719 0.013341 00:06
104 2.895752 3.012321 0.014138 00:05
105 2.912092 2.999137 0.013935 00:06
106 2.887507 3.055579 0.013303 00:06
107 2.785923 2.775874 0.013530 00:06
108 2.790696 2.943224 0.013407 00:06
109 2.736248 2.692200 0.013393 00:06
110 2.777961 3.015282 0.013393 00:06
111 2.691288 2.745592 0.013437 00:06
112 2.789230 3.077176 0.013408 00:06
113 2.915753 3.117961 0.013360 00:06
114 2.959435 3.018226 0.013320 00:06
115 2.860774 3.009152 0.013365 00:06
116 2.801426 3.012006 0.013323 00:06
117 2.745043 2.931821 0.013390 00:05
118 2.829792 3.212095 0.013460 00:06
119 2.928438 3.093549 0.013554 00:06
120 2.809355 2.991477 0.013369 00:06
121 2.842006 2.887894 0.013425 00:06
122 2.752259 3.027925 0.013356 00:06
123 2.717536 2.930707 0.013380 00:06
124 2.659468 2.738417 0.013317 00:06
125 2.632839 2.762275 0.013374 00:06
126 2.615485 2.933614 0.013515 00:06
127 2.772539 2.849458 0.013346 00:06
128 2.692068 2.821917 0.013313 00:06
129 2.674964 2.925591 0.013515 00:06
130 2.735467 3.376877 0.013440 00:06
131 2.766207 2.762553 0.013410 00:06
132 2.667580 2.874986 0.013671 00:06
133 2.749335 3.129861 0.013531 00:06
134 2.729246 3.254984 0.013477 00:06
135 2.695539 2.780476 0.013463 00:06
136 2.816230 2.873311 0.014055 00:06
137 2.864419 2.713563 0.013490 00:06
138 2.710224 2.976030 0.013462 00:06
139 2.810597 2.733291 0.013761 00:06
140 2.720426 3.113232 0.013504 00:06
141 2.684166 2.973913 0.013646 00:06
142 2.658831 2.694964 0.013537 00:06
143 2.551434 2.683068 0.013576 00:06
144 2.607995 2.680787 0.013746 00:06
145 2.511031 2.631121 0.013550 00:06
146 2.532910 2.716851 0.014193 00:06
147 2.499316 2.591249 0.013756 00:06
148 2.504299 2.482930 0.013924 00:06
149 2.454766 2.739512 0.013410 00:06
150 2.410347 2.539483 0.014047 00:06
151 2.357674 2.650381 0.013451 00:06
152 2.516414 3.110652 0.013449 00:06
153 2.511709 2.586704 0.013589 00:06
154 2.395916 2.543531 0.013545 00:06
155 2.325738 2.390544 0.013584 00:06
156 2.317294 2.524722 0.013605 00:06
157 2.320530 2.516788 0.013581 00:06
158 2.342404 2.566997 0.013457 00:06
159 2.357474 2.531506 0.013506 00:06
160 2.353716 2.508490 0.013706 00:06
161 2.288348 2.441255 0.013420 00:06
162 2.280735 2.444789 0.013553 00:06
163 2.253160 2.389358 0.013429 00:06
164 2.194379 2.245656 0.013401 00:06
165 2.221519 2.473610 0.013492 00:06
166 2.213005 2.353466 0.013761 00:06
167 2.189918 2.209893 0.013438 00:06
168 2.137919 2.263444 0.013504 00:06
169 2.122726 2.222945 0.013469 00:06
170 2.120569 2.280406 0.013610 00:06
171 2.116824 2.240684 0.013397 00:06
172 2.122046 2.278723 0.013423 00:06
173 2.112754 2.258956 0.013546 00:06
174 2.102790 2.276128 0.013550 00:06
175 2.076454 2.174576 0.013555 00:06
176 2.057850 2.256071 0.013571 00:06
177 2.060586 2.203174 0.013507 00:06
178 2.050041 2.252553 0.013794 00:06
179 2.037670 2.193828 0.013548 00:06
180 2.053100 2.229102 0.013680 00:06
181 2.035722 2.204854 0.013639 00:06
182 2.027938 2.246414 0.013695 00:06
183 2.019191 2.139696 0.013623 00:06
184 1.986986 2.146453 0.013596 00:06
185 1.961553 2.127800 0.013521 00:06
186 1.962777 2.143257 0.013615 00:06
187 1.961414 2.143002 0.013601 00:06
188 1.944040 2.112002 0.013604 00:06
189 1.938257 2.106662 0.013668 00:06
190 1.929222 2.102477 0.013673 00:06
191 1.925581 2.105546 0.013615 00:06
192 1.916796 2.102625 0.013668 00:06
193 1.909546 2.101063 0.013662 00:06
194 1.893878 2.090551 0.013671 00:06
195 1.880268 2.084363 0.013660 00:06
196 1.871181 2.080494 0.013662 00:06
197 1.861585 2.075231 0.013661 00:06
198 1.857682 2.073144 0.013658 00:06
199 1.853532 2.072075 0.013657 00:06
wandb.finish() # call wandb.finish() after training or your logs may be incomplete

Waiting for W&B process to finish, PID 1810100
Program ended successfully.
Find user logs for this run at: /home/shawley/fastproaudio/wandb/run-20211219_235350-2rkico9o/logs/debug.log
Find internal logs for this run at: /home/shawley/fastproaudio/wandb/run-20211219_235350-2rkico9o/logs/debug-internal.log

Run summary:

crop_mse     0.01366
dampening_0  0
epoch        200
lr_0         0.0
mom_0        0.95
nesterov_0   False
raw_loss     1.83874
train_loss   1.85353
valid_loss   2.07207
wd_0         0

Run history:

crop_mse     █▃▁▁▂▂▂▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
dampening_0  ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch        ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr_0         ▁▂▂▃▄▅▆▇███████▇▇▇▇▇▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
mom_0        ██▇▆▅▄▃▂▁▁▁▁▁▁▁▂▂▂▂▂▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇█████
nesterov_0   ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
raw_loss     █▅▅▄▄▃▄▃▄▃▃▃▃▃▃▂▃▃▃▃▃▃▃▂▂▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁
train_loss   █▅▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▃▂▂▂▂▂▂▁▁▁▁▁▁
valid_loss   █▆▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▃▂▂▂▂▂▂▁▁▁▁▁▁▁
wd_0         ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Synced 5 W&B file(s), 800 media file(s), 0 artifact file(s) and 0 other file(s)
learn.save('time_align')
Path('models/time_align.pth')

Go check out the resulting run logs, graphs, and audio samples at https://wandb.ai/drscotthawley/time-align, or... lemme see if I can embed some results below:

Inference / Evaluation

Load in the testing data

test_dataset = TimeAlignDataset_fastai(args.root_dir, 
                    preload=args.preload,
                    half=True if args.precision == 16 else False,
                    subset='test',
                    length=args.eval_length,
                    positional_encoding=USE_POSITIONAL_ENCODING)

test_dataloader = torch.utils.data.DataLoader(test_dataset, 
                    shuffle=False,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)

learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func, cbs=[])
learn.load('time_align')
Located 28 examples totaling 3.82 min in the test subset.
/tmp/ipykernel_1809707/2740461433.py:6: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  positional_input = ((torch.arange(seq_length).unsqueeze(0) // 2**torch.arange(c).unsqueeze(1))%2).float()
<fastai.learner.Learner at 0x7fc7ebb7d250>
 

Let's get some predictions from the model. Note that the length of these predictions will be greater than in training, because we specified the eval length differently:

print(args.train_length, args.eval_length)
65536 131072

Handy routine to grab some data and run it through the model to get predictions:

def get_pred_batch(dataloader, crop_them=True, causal=True):
    packed, target = next(iter(dataloader))
    input, params = packed[:,:,0:-dict_args['nparams']], packed[:,:,-dict_args['nparams']:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))
    print("pred.shape = ",pred.shape)
    if crop_them: 
        target = causal_crop(target, pred.shape[-1]) if causal else center_crop(target, pred.shape[-1])
        input = causal_crop(input, pred.shape[-1]) if causal else center_crop(input, pred.shape[-1])
    input, params, target, pred = [x.detach().cpu() for x in [input, params, target, pred]]
    return input, params, target, pred
input, params, target, pred = get_pred_batch(test_dataloader, causal=dict_args['causal'])
i = np.random.randint(input.shape[0])  # pick a random element of the batch
print(f"------- i = {i} ---------\n")
print(f"input:")
show_audio(input[i][:USER_INPUT_CHANNELS], sample_rate, mc_plot=False)  # don't show positional encoding
pred.shape =  torch.Size([8, 3, 85152])
------- i = 1 ---------

input:
Shape: (3, 85152), Dtype: torch.float32, Duration: 5.322 s
Max:  1.000,  Min: -1.000, Mean:  0.000, Std Dev:  0.115
print(f"prediction:")
show_audio(pred[i], sample_rate, mc_plot=False)
prediction:
Shape: (3, 85152), Dtype: torch.float32, Duration: 5.322 s
Max:  0.346,  Min: -0.333, Mean: -0.002, Std Dev:  0.044
print(f"target:")
show_audio(target[i], sample_rate, mc_plot=False)
target:
Shape: (3, 85152), Dtype: torch.float32, Duration: 5.322 s
Max:  1.000,  Min: -1.000, Mean:  0.000, Std Dev:  0.122