Work in progress. If you come back later there'll probably be more. - SHH, 10/24/21.


There are many tutorials on audio classification; usually these involve rendering your audio as (mel-)spectrograms and doing image classification on those. Tutorials on audio processing or generation are less common, but the list is growing. Lately I've become somewhat proficient with fastai and would like to port some audio processing examples over to it. There are a few choices of tasks and datasets -- there's been great work on source separation, for example.

My Choice: Reproduce Micro-TCN

Since I've been interested in audio effects, I'll choose the task of reproducing Christian Steinmetz and Josh Reiss's Micro-TCN work for learning to profile audio compressors. Their code uses PyTorch Lightning rather than fastai, but we should be able to do a bare-minimum integration with fastai by following Zach Mueller's prescription. The experience gained here can hopefully carry over when adapting other audio tasks & models to work with fastai.

Other things we could try (in later notebooks)

We could grab any old audio data and learn some kind of inverse effect such as denoising: add noise to the audio files and then train the network to remove it. But what other audio datasets are available?
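As a concrete (hypothetical) sketch of what that denoising setup could look like -- the noise level and helper name here are made up purely for illustration:

import torch, torchaudio

def make_denoising_pair(fname, noise_level=0.01):
    "Load a clean audio file and return a (noisy input, clean target) pair."
    clean, sr = torchaudio.load(fname)                      # (channels, samples)
    noisy = clean + noise_level * torch.randn_like(clean)   # add white noise
    return noisy, clean, sr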

Installs and imports

# Next line only executes on Colab. Colab users: Please enable GPU in Edit > Notebook settings
! [ -e /content ] && pip install -Uqq pip fastai git+https://github.com/drscotthawley/fastproaudio.git

# Additional installs for this tutorial
%pip install -q fastai_minima torchsummary pyzenodo3 wandb

# Install micro-tcn and auraloss packages (from source, will take a little while)
%pip install -q wheel --ignore-requires-python git+https://github.com/csteinmetz1/micro-tcn.git  git+https://github.com/csteinmetz1/auraloss

# After this cell finishes, restart the kernel and continue below
Note: you may need to restart the kernel to use updated packages.
  WARNING: Missing build requirements in pyproject.toml for git+https://github.com/csteinmetz1/auraloss.
  WARNING: The project does not specify a build backend, and pip cannot fall back to setuptools without 'wheel'.
Note: you may need to restart the kernel to use updated packages.
from fastai.vision.all import *
from fastai.text.all import *
from fastai.callback.fp16 import *
import wandb
from fastai.callback.wandb import *
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
from IPython.display import Audio 
import matplotlib.pyplot as plt
import torchsummary
from fastproaudio.core import *
from glob import glob
import json

use_fastaudio = False
if use_fastaudio:
    from fastaudio.core.all import *
    from fastaudio.augment.all import *
    from fastaudio.ci import skip_if_ci

Download and Inspect the Data

The "SignalTrain LA2A Reduced" dataset is something I made Friday night. It's only the first 10 seconds of each of the 20-minute audio files making up the full SignalTrain LA2A dataset, which consists of lots of audio files run through an LA2A audio compressor at different knob settings. At 200 MB, the Reduced version is enough to train the model some and see that it's working for the purposes of this demo, though you'd probably want more data to make a high-quality model. (If you'd rather train using the full 20 GB dataset, use URLs.SIGNALTRAIN_LA2A_1_1 below, but everything will take longer!)

path = get_audio_data(URLs.SIGNALTRAIN_LA2A_REDUCED); path
Path('/home/shawley/.fastai/data/SignalTrain_LA2A_Reduced')
fnames_in = sorted(glob(str(path)+'/*/input*'))
fnames_targ = sorted(glob(str(path)+'/*/*targ*'))
ind = -1   # pick one spot in the list of files
fnames_in[ind], fnames_targ[ind]
('/home/shawley/.fastai/data/SignalTrain_LA2A_Reduced/Val/input_260_.wav',
 '/home/shawley/.fastai/data/SignalTrain_LA2A_Reduced/Val/target_260_LA2A_2c__1__85.wav')
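Note that the knob settings for each target file appear to be encoded in its filename -- my assumption is that the two numbers after "LA2A_2c" are the comp/limit switch position and the peak-reduction knob value. A quick (hypothetical) parser:

from pathlib import Path

def parse_knob_settings(fname):
    "e.g. 'target_260_LA2A_2c__1__85.wav' -> (1.0, 85.0)"
    parts = Path(fname).stem.split('__')   # ['target_260_LA2A_2c', '1', '85']
    return float(parts[1]), float(parts[2])

parse_knob_settings(fnames_targ[ind])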

Input audio

import warnings
warnings.filterwarnings("ignore", category=UserWarning)   # turn off annoying matplotlib warnings

waveform, sample_rate = torchaudio.load(fnames_in[ind])
show_audio(waveform, sample_rate)
Shape: (1, 441000), Dtype: torch.float32, Duration: 10.0 s
Max:  0.225,  Min: -0.218, Mean:  0.000, Std Dev:  0.038

Target output audio

target, sr_targ = torchaudio.load(fnames_targ[ind])
show_audio(target, sr_targ)
Shape: (1, 441000), Dtype: torch.float32, Duration: 10.0 s
Max:  0.091,  Min: -0.103, Mean: -0.000, Std Dev:  0.021

Let's look at the difference.

Difference

show_audio(target - waveform, sample_rate)
Shape: (1, 441000), Dtype: torch.float32, Duration: 10.0 s
Max:  0.144,  Min: -0.159, Mean: -0.000, Std Dev:  0.018

Import Steinmetz's Code (and make it work with fastai)

Datasets and Dataloader definitions: These use micro-tcn's custom SignalTrainLA2ADataset class. Each Dataset object returns 3 items from __getitem__(): input, target, and params. To make our lives easier with fastai, we're going to "pack" the input and the params together with our own subclassed Dataset, then we'll have the model unpack them later.

from microtcn.data import SignalTrainLA2ADataset

class SignalTrainLA2ADataset_fastai(SignalTrainLA2ADataset):
    "For fastai's sake, have getitem pack the inputs and params together"
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def __getitem__(self, idx):
        input, target, params = super().__getitem__(idx)
        return torch.cat((input,params),dim=-1), target   # pack input and params together
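Just to make the packing trick concrete, here's a quick shape check of the round trip on dummy per-item tensors (illustrative sizes only; the model below does the same slicing on batched 3-D tensors):

x = torch.randn(1, 65536)            # mono audio input
p = torch.tensor([[1.0, 85.0]])      # two knob-setting params
packed = torch.cat((x, p), dim=-1)   # shape: (1, 65538)
x2, p2 = packed[:, :-2], packed[:, -2:]
assert x2.shape == x.shape and p2.shape == p.shape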
class Args(object):  # stand-in for argparse. these are all micro-tcn defaults
    model_type ='tcn'
    root_dir = str(path)
    preload = False
    sample_rate = 44100
    shuffle = True
    train_subset = 'train'
    val_subset = 'val'
    train_length = 65536
    train_fraction = 1.0
    eval_length = 131072
    batch_size = 8   # original is 32, my laptop needs smaller, esp. w/o half precision
    num_workers = 4
    precision = 32  # LEAVE AS 32 FOR NOW: HALF PRECISION (16) NOT WORKING YET -SHH
    n_params = 2
    
args = Args()

#if args.precision == 16:  torch.set_default_dtype(torch.float16)

# setup the dataloaders
train_dataset = SignalTrainLA2ADataset_fastai(args.root_dir, 
                    subset=args.train_subset, 
                    fraction=args.train_fraction,
                    half=True if args.precision == 16 else False, 
                    preload=args.preload, 
                    length=args.train_length)

train_dataloader = torch.utils.data.DataLoader(train_dataset, 
                    shuffle=args.shuffle,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)

val_dataset = SignalTrainLA2ADataset_fastai(args.root_dir, 
                    preload=args.preload,
                    half=True if args.precision == 16 else False,
                    subset=args.val_subset,
                    length=args.eval_length)

val_dataloader = torch.utils.data.DataLoader(val_dataset, 
                    shuffle=False,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)
[(0.0, 0.0), (0.0, 5.0), (0.0, 15.0), (0.0, 20.0), (0.0, 25.0), (0.0, 30.0), (0.0, 35.0), (0.0, 40.0), (0.0, 45.0), (0.0, 55.0), (0.0, 60.0), (0.0, 65.0), (0.0, 70.0), (0.0, 75.0), (0.0, 80.0), (0.0, 85.0), (0.0, 90.0), (0.0, 95.0), (0.0, 100.0), (1.0, 0.0), (1.0, 5.0), (1.0, 15.0), (1.0, 20.0), (1.0, 25.0), (1.0, 30.0), (1.0, 35.0), (1.0, 40.0), (1.0, 45.0), (1.0, 50.0), (1.0, 55.0), (1.0, 60.0), (1.0, 65.0), (1.0, 75.0), (1.0, 80.0), (1.0, 85.0), (1.0, 90.0), (1.0, 95.0), (1.0, 100.0)]
Total Examples: 396     Total classes: 38
Fraction examples: 396    Examples/class: 10
Training with 0.25 min per class    Total of 9.41 min
Located 380 examples totaling 9.41 min in the train subset.
Located 45 examples totaling 2.23 min in the val subset.

If the user requested fp16 precision then we need to install NVIDIA apex:

if False and args.precision == 16:
    %pip install -q --disable-pip-version-check --no-cache-dir git+https://github.com/NVIDIA/apex
    from apex.fp16_utils import convert_network

Define the model(s)

Christian defined a lot of models. We'll do the TCN-300; the LSTM model depends on a lot of Lightning internals, so we'll skip it here.

from microtcn.tcn_bare import TCNModel as TCNModel
#from microtcn.lstm import LSTMModel # actually the LSTM depends on a lot of Lightning stuff, so we'll skip that
from microtcn.utils import center_crop, causal_crop

class TCNModel_fastai(TCNModel):
    "For fastai's sake, unpack the inputs and params"
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def forward(self, x, p=None):
        if (p is None) and (self.nparams > 0):  # unpack the params if needed
            assert len(list(x.size())) == 3   # sanity check 
            x, p = x[:,:,0:-self.nparams], x[:,:,-self.nparams:]
        return super().forward(x, p=p)
# micro-tcn defines several different model configurations. I just chose one of them. 
train_configs = [
      {"name" : "TCN-300",
     "model_type" : "tcn",
     "nblocks" : 10,
     "dilation_growth" : 2,
     "kernel_size" : 15,
     "causal" : False,
     "train_fraction" : 1.00,
     "batch_size" : args.batch_size
    }
]

dict_args = train_configs[0]
dict_args["nparams"] = 2

model = TCNModel_fastai(**dict_args)
dtype = torch.float32

Let's take a look at the model:

# this summary allows one to compare the original TCNModel with the TCNModel_fastai
if type(model) == TCNModel_fastai:
    torchsummary.summary(model, [(1,args.train_length)], device="cpu")
else:
    torchsummary.summary(model, [(1,args.train_length),(1,2)], device="cpu")
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                [-1, 1, 16]              48
              ReLU-2                [-1, 1, 16]               0
            Linear-3                [-1, 1, 32]             544
              ReLU-4                [-1, 1, 32]               0
            Linear-5                [-1, 1, 32]           1,056
              ReLU-6                [-1, 1, 32]               0
            Conv1d-7            [-1, 32, 65520]             480
            Linear-8                [-1, 1, 64]           2,112
       BatchNorm1d-9            [-1, 32, 65520]               0
             FiLM-10            [-1, 32, 65520]               0
            PReLU-11            [-1, 32, 65520]              32
           Conv1d-12            [-1, 32, 65534]              32
         TCNBlock-13            [-1, 32, 65520]               0
           Conv1d-14            [-1, 32, 65492]          15,360
           Linear-15                [-1, 1, 64]           2,112
      BatchNorm1d-16            [-1, 32, 65492]               0
             FiLM-17            [-1, 32, 65492]               0
            PReLU-18            [-1, 32, 65492]              32
           Conv1d-19            [-1, 32, 65520]              32
         TCNBlock-20            [-1, 32, 65492]               0
           Conv1d-21            [-1, 32, 65436]          15,360
           Linear-22                [-1, 1, 64]           2,112
      BatchNorm1d-23            [-1, 32, 65436]               0
             FiLM-24            [-1, 32, 65436]               0
            PReLU-25            [-1, 32, 65436]              32
           Conv1d-26            [-1, 32, 65492]              32
         TCNBlock-27            [-1, 32, 65436]               0
           Conv1d-28            [-1, 32, 65324]          15,360
           Linear-29                [-1, 1, 64]           2,112
      BatchNorm1d-30            [-1, 32, 65324]               0
             FiLM-31            [-1, 32, 65324]               0
            PReLU-32            [-1, 32, 65324]              32
           Conv1d-33            [-1, 32, 65436]              32
         TCNBlock-34            [-1, 32, 65324]               0
           Conv1d-35            [-1, 32, 65100]          15,360
           Linear-36                [-1, 1, 64]           2,112
      BatchNorm1d-37            [-1, 32, 65100]               0
             FiLM-38            [-1, 32, 65100]               0
            PReLU-39            [-1, 32, 65100]              32
           Conv1d-40            [-1, 32, 65324]              32
         TCNBlock-41            [-1, 32, 65100]               0
           Conv1d-42            [-1, 32, 64652]          15,360
           Linear-43                [-1, 1, 64]           2,112
      BatchNorm1d-44            [-1, 32, 64652]               0
             FiLM-45            [-1, 32, 64652]               0
            PReLU-46            [-1, 32, 64652]              32
           Conv1d-47            [-1, 32, 65100]              32
         TCNBlock-48            [-1, 32, 64652]               0
           Conv1d-49            [-1, 32, 63756]          15,360
           Linear-50                [-1, 1, 64]           2,112
      BatchNorm1d-51            [-1, 32, 63756]               0
             FiLM-52            [-1, 32, 63756]               0
            PReLU-53            [-1, 32, 63756]              32
           Conv1d-54            [-1, 32, 64652]              32
         TCNBlock-55            [-1, 32, 63756]               0
           Conv1d-56            [-1, 32, 61964]          15,360
           Linear-57                [-1, 1, 64]           2,112
      BatchNorm1d-58            [-1, 32, 61964]               0
             FiLM-59            [-1, 32, 61964]               0
            PReLU-60            [-1, 32, 61964]              32
           Conv1d-61            [-1, 32, 63756]              32
         TCNBlock-62            [-1, 32, 61964]               0
           Conv1d-63            [-1, 32, 58380]          15,360
           Linear-64                [-1, 1, 64]           2,112
      BatchNorm1d-65            [-1, 32, 58380]               0
             FiLM-66            [-1, 32, 58380]               0
            PReLU-67            [-1, 32, 58380]              32
           Conv1d-68            [-1, 32, 61964]              32
         TCNBlock-69            [-1, 32, 58380]               0
           Conv1d-70            [-1, 32, 51212]          15,360
           Linear-71                [-1, 1, 64]           2,112
      BatchNorm1d-72            [-1, 32, 51212]               0
             FiLM-73            [-1, 32, 51212]               0
            PReLU-74            [-1, 32, 51212]              32
           Conv1d-75            [-1, 32, 58380]              32
         TCNBlock-76            [-1, 32, 51212]               0
           Conv1d-77             [-1, 1, 51212]              33
================================================================
Total params: 162,161
Trainable params: 162,161
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.25
Forward/backward pass size (MB): 922.11
Params size (MB): 0.62
Estimated Total Size (MB): 922.98
----------------------------------------------------------------

Getting the model into fastai form

Zach Mueller made a very helpful fastai_minima package that we'll use, following his instructions.

TODO: Zach says I should either use fastai or fastai_minima, not mix them like I'm about to do. But what I have below is the only thing that works right now. ;-)

# I guess we could've imported these up at the top of the notebook...
from torch import optim
from fastai_minima.optimizer import OptimWrapper
#from fastai_minima.learner import Learner  # this doesn't include lr_find()
from fastai.learner import Learner
from fastai_minima.learner import DataLoaders
#from fastai_minima.callback.training_utils import CudaCallback, ProgressCallback # not sure if I need these
def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, **kwargs))

dls = DataLoaders(train_dataloader, val_dataloader)

Checking: Let's make sure the Dataloaders are working

if args.precision==16: 
    dtype = torch.float16
    model = convert_network(model, torch.float16)

model = model.to('cuda:0')
if type(model) == TCNModel_fastai:
    print("We're using Hawley's modified code")
    packed, targ = dls.one_batch()
    inp, params = packed[:,:,0:-dict_args['nparams']], packed[:,:,-dict_args['nparams']:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))
else:
    print("We're using Christian's version of Dataloader and model")
    inp, targ, params = dls.one_batch()
    pred = model.forward(inp.to('cuda:0',dtype=dtype), p=params.to('cuda:0', dtype=dtype))
print(f"input  = {inp.size()}\ntarget = {targ.size()}\nparams = {params.size()}\npred   = {pred.size()}")
We're using Hawley's modified code
input  = torch.Size([8, 1, 65536])
target = torch.Size([8, 1, 65536])
params = torch.Size([8, 1, 2])
pred   = torch.Size([8, 1, 51214])

We can make the pred and target the same length by cropping when we compute the loss:

class Crop_Loss:
    "Crop target size to match preds"
    def __init__(self, axis=-1, causal=False, reduction="mean", func=nn.L1Loss):
        store_attr()
        self.loss_func = func()
    def __call__(self, pred, targ):
        targ = causal_crop(targ, pred.shape[-1]) if self.causal else center_crop(targ, pred.shape[-1])
        #pred, targ = TensorBase(pred), TensorBase(targ)
        assert pred.shape == targ.shape, f'pred.shape = {pred.shape} but targ.shape = {targ.shape}'
        loss = self.loss_func(pred, targ)
        return loss.flatten().mean() if self.reduction == "mean" else loss.flatten().sum()
    

# we could add a metric like MSE if we want
def crop_mse(pred, targ, causal=False): 
    targ = causal_crop(targ, pred.shape[-1]) if causal else center_crop(targ, pred.shape[-1])
    return ((pred - targ)**2).mean()
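For reference, here's roughly what I understand the micro-tcn cropping helpers to do, based on how they're used here (a sketch, not the actual micro-tcn code): center_crop trims equal amounts from both ends of the last axis, while causal_crop keeps only the trailing samples.

# illustrative stand-ins for microtcn.utils.center_crop / causal_crop
def my_center_crop(x, length):
    start = (x.shape[-1] - length) // 2
    return x[..., start:start + length]

def my_causal_crop(x, length):
    return x[..., -length:]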

Enable logging with WandB:

wandb.login()
wandb: Currently logged in as: drscotthawley (use `wandb login --relogin` to force relogin)
True

Define the fastai Learner and callbacks

We're going to add a new custom WandBAudio callback further below, which we'll use when we call fit().

WandBAudio Callback

In order to log audio samples, let's write our own audio-logging callback for fastai:

class WandBAudio(Callback):
    """Progress-like callback: log audio to WandB"""
    order = ProgressCallback.order+1
    def __init__(self, n_preds=5, sample_rate=44100):
        store_attr()

    def after_epoch(self):  
        if not self.learn.training:
            with torch.no_grad():
                preds, targs = [x.detach().cpu().numpy().copy() for x in [self.learn.pred, self.learn.y]]
            log_dict = {}
            for i in range(min(self.n_preds, preds.shape[0])):  # note wandb only supports mono
                log_dict[f"preds_{i}"] = wandb.Audio(preds[i,0,:], caption=f"preds_{i}", sample_rate=self.sample_rate)
            wandb.log(log_dict)

Learner and wandb init

wandb.init(project='micro-tcn-fastai')  # no run name given; could pass e.g. name=json.dumps(dict_args)

learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func,
               cbs= [WandbCallback()])

Train the model

We can use the fastai learning rate finder to suggest a learning rate:

learn.lr_find(end_lr=0.1) 
SuggestedLRs(valley=0.0006918309954926372)

And now we'll train using the one-cycle LR schedule, with the WandBAudio callback. (Ignore any warning messages)

epochs = 20  # change to 50 for better results but a longer wait
learn.fit_one_cycle(epochs, lr_max=3e-3, cbs=WandBAudio(sample_rate=args.sample_rate))
Could not gather input dimensions
WandbCallback requires use of "SaveModelCallback" to log best model
WandbCallback was not able to prepare a DataLoader for logging prediction samples -> 
epoch train_loss valid_loss crop_mse time
0 0.143242 0.098410 0.020299 00:06
1 0.096335 0.061745 0.007963 00:05
2 0.065788 0.035349 0.003570 00:05
3 0.045120 0.027977 0.001921 00:05
4 0.034311 0.023991 0.001443 00:05
5 0.026962 0.020367 0.001035 00:06
6 0.023846 0.020088 0.000883 00:05
7 0.021708 0.015346 0.000704 00:06
8 0.019866 0.026435 0.001117 00:06
9 0.017529 0.012842 0.000533 00:05
10 0.016500 0.013006 0.000504 00:05
11 0.015390 0.011723 0.000425 00:06
12 0.014275 0.012459 0.000437 00:06
13 0.013890 0.012470 0.000408 00:05
14 0.013401 0.013570 0.000454 00:05
15 0.012933 0.011421 0.000390 00:06
16 0.012545 0.010564 0.000362 00:05
17 0.012153 0.011395 0.000392 00:05
18 0.011879 0.010478 0.000356 00:05
19 0.011740 0.010412 0.000361 00:06
wandb.finish() # call wandb.finish() after training or your logs may be incomplete

Waiting for W&B process to finish, PID 1852379... (success).

Run history:

crop_mse     █▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dampening_0  ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch        ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr_0         ▁▂▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
mom_0        ██▇▆▅▄▃▂▁▁▁▁▁▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇█████
nesterov_0   ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
raw_loss     █▆▄▃▃▃▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss   █▇▅▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss   █▅▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁
wd_0         ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:

crop_mse     0.00036
dampening_0  0
epoch        20
lr_0         0.0
mom_0        0.95
nesterov_0   False
raw_loss     0.01222
train_loss   0.01174
valid_loss   0.01041
wd_0         0
Synced 5 W&B file(s), 100 media file(s), 0 artifact file(s) and 0 other file(s)
Synced fresh-salad-56: https://wandb.ai/drscotthawley/micro-tcn-fastai/runs/9w1h46em
Find logs at: ./wandb/run-20211025_091818-9w1h46em/logs
learn.save('micro-tcn-fastai')
Path('models/micro-tcn-fastai.pth')

Go check out the resulting run logs, graphs, and audio samples at https://wandb.ai/drscotthawley/micro-tcn-fastai, or... lemme see if I can embed some results below:

...ok it looks like the WandB results iframe (with cool graphs & audio) is getting filtered out of the docs (by nbdev and/or jekyll), but if you open this notebook file -- e.g. click the "Open in Colab" badge at the top -- then scroll down and you'll see the report. Or just go to the WandB link posted above!

TODO: Inference / Evaluation / Analysis

Load in the testing data

test_dataset = SignalTrainLA2ADataset_fastai(args.root_dir, 
                    preload=args.preload,
                    half=True if args.precision == 16 else False,
                    subset='test',
                    length=args.eval_length)

test_dataloader = torch.utils.data.DataLoader(test_dataset, 
                    shuffle=False,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)

learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func, cbs=[])
learn.load('micro-tcn-fastai')
Located 9 examples totaling 0.45 min in the test subset.
<fastai.learner.Learner at 0x7fcff933b430>

^^ 9 examples? I thought there were only 3:

!ls {path}/Test
input_235_.wav	input_259_.wav		       target_256_LA2A_2c__1__65.wav
input_256_.wav	target_235_LA2A_2c__0__65.wav  target_259_LA2A_2c__1__80.wav

...Ok I don't fully understand that yet -- presumably the Dataset slices each 10-second file into multiple examples of eval_length samples (441000 // 131072 = 3 per file, times 3 files = 9). Moving on:

Let's get some predictions from the model. Note that the length of these predictions will be greater than in training, because we specified the lengths differently:

print(args.train_length, args.eval_length)
65536 131072
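In seconds (at args.sample_rate = 44100), those lengths work out to roughly 1.5 s for training and 3 s for evaluation:

print(f"train: {args.train_length/args.sample_rate:.2f} s,  eval: {args.eval_length/args.sample_rate:.2f} s")
# train: 1.49 s,  eval: 2.97 s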

Handy routine to grab some data and run it through the model to get predictions:

def get_pred_batch(dataloader, crop_target=True, causal=False):
    packed, target = next(iter(dataloader))
    input, params = packed[:,:,0:-dict_args['nparams']], packed[:,:,-dict_args['nparams']:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))
    if crop_target: target = causal_crop(target, pred.shape[-1]) if causal else center_crop(target, pred.shape[-1])
    input, params, target, pred = [x.detach().cpu() for x in [input, params, target, pred]]
    return input, params, target, pred
input, params, target, pred = get_pred_batch(test_dataloader, causal=dict_args['causal'])
i = 0  # just look at the first element
print(f"------- i = {i} ---------\n")
print(f"prediction:")
show_audio(pred[i], sample_rate)
------- i = 0 ---------

prediction:
Shape: (1, 116750), Dtype: torch.float32, Duration: 2.647392290249433 s
Max:  0.139,  Min: -0.147, Mean:  0.000, Std Dev:  0.037
print(f"target:")
show_audio(target[i], sample_rate)
target:
Shape: (1, 116750), Dtype: torch.float32, Duration: 2.647392290249433 s
Max:  0.215,  Min: -0.202, Mean:  0.000, Std Dev:  0.053

TODO: More. We're not finished. I'll come back and add more to this later.

Deployment / Plugins

Check out Christian's GitHub page for micro-tcn, where he provides instructions and JUCE code for turning the model into an audio plugin. Pretty sure you can only do this with the causal models, which I didn't include -- yet!