install = False  # can set to False to skip this part, e.g. for re-running in same session
if install:  # ffmpeg is to add MP3 support to Colab
    !yes | sudo apt install ffmpeg
    !pip install -Uqq einops gdown
    !pip install -Uqq git+https://github.com/drscotthawley/aeiou
    !pip install -Uqq git+https://github.com/drscotthawley/audio-algebra
aa_mixer
Trying to map audio embeddings to vector spaces, for mixing.
Basic setup of hardware environment
accelerator = accelerate.Accelerator()
hprint = HostPrinter(accelerator)  # this just prints only on the interactive node
device = accelerator.device
#device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
#if torch.backends.mps.is_available():
#    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
print("device = ", device)
Main parameters for the run/model
seed = 2

args_dict = {'num_quantizers':0, 'sample_size': 65536, 'sample_rate':48000, 'latent_dim': 64, 'pqmf_bands':1, 'ema_decay':0.995}
#global_args = namedtuple("global_args", args_dict.keys())(*args_dict.values())
class DictObj:
def __init__(self, in_dict:dict):
assert isinstance(in_dict, dict), "in_dict is not a dict"
for key, val in in_dict.items():
if isinstance(val, (list, tuple)):
setattr(self, key, [DictObj(x) if isinstance(x, dict) else x for x in val])
else:
setattr(self, key, DictObj(val) if isinstance(val, dict) else val)
global_args = DictObj(args_dict)
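Just to illustrate how DictObj works (this check isn't needed for anything downstream), the dict entries are now reachable as attributes:

# dict keys become attributes; purely illustrative
print(global_args.sample_rate, global_args.latent_dim, global_args.num_quantizers)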
Set Up Data Loading
"Setting up dataset")
hprint(= global_args
args = f'{os.getenv("HOME")}/datasets/BDCT-0-chunk-48000'
args.training_dir = 2
args.num_workers
= 256
args.batch_size
= 0.1
load_frac
torch.manual_seed(seed)= AudioDataset([args.training_dir], load_frac=load_frac)
train_set = torchdata.DataLoader(train_set, args.batch_size, shuffle=True,
train_dl =args.num_workers, persistent_workers=True, pin_memory=True)
num_workers
# TODO: need to make val unique. for now just repeat train
val_set = AudioDataset([args.training_dir], load_frac=load_frac/4)
val_dl = torchdata.DataLoader(val_set, args.batch_size, shuffle=False,
    num_workers=args.num_workers, persistent_workers=True, pin_memory=True)

torch.manual_seed(seed)
val_iter = iter(val_dl)
train_iter = iter(train_dl)

print("len(train_set), len(val_set) =", len(train_set), len(val_set))
And let’s listen to a bit of audio
batch = next(val_iter)
batch = next(val_iter)  # two nexts bc i don't like the first one
print("batch.shape = ", batch.shape)
playable_spectrogram(batch[0], output_type='live')  # clear this output later if you want to keep .ipynb file size small
Set up the Given [Auto]Encoder Model(s)
Note that initially we're only going to be using the encoder part. The decoder – with all of its sampling code, etc. – will be useful eventually, so we'll go ahead and define it. But FYI it won't be used at all while training the AA mixer model.
Download the checkpoint file for the DVAE
on_colab = os.path.exists('/content')
if on_colab:
    from google.colab import drive
    drive.mount('/content/drive/')
    ckpt_file = '/content/drive/MyDrive/AI/checkpoints/epoch=53-step=200000.ckpt'
else:
    ckpt_file = 'checkpoint.ckpt'
    if not os.path.exists(ckpt_file):
        url = 'https://drive.google.com/file/d/1C3NMdQlmOcArGt1KL7pH32KtXVCOfXKr/view?usp=sharing'
        # downloading large files from GDrive requires special treatment to bypass the dialog button it wants to throw up
        id = url.split('/')[-2]
        cmd = f'wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate \'https://docs.google.com/uc?export=download&id={id}\' -O- | sed -rn \'s/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p\')&id={id}" -O {ckpt_file} && rm -rf /tmp/cookies.txt'
        print("cmd = \n", cmd)
        subprocess.run(cmd, shell=True, check=True)
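Since gdown was installed above, a simpler alternative is to let it handle the Google Drive confirmation step for us. This is just a sketch of that option (not what the cell above actually uses), assuming the file remains publicly shared:

# alternative sketch: let gdown resolve the Drive confirm-token dance
import gdown
file_id = '1C3NMdQlmOcArGt1KL7pH32KtXVCOfXKr'  # same file id as in the wget command above
if not os.path.exists(ckpt_file):
    gdown.download(f'https://drive.google.com/uc?id={file_id}', ckpt_file, quiet=False)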
given_model = DiffusionDVAE.load_from_checkpoint(ckpt_file, global_args=global_args)
given_model.eval()  # disable randomness, dropout, etc...

# attach some arg values to the model
given_model.demo_samples = global_args.sample_size
given_model.quantized = global_args.num_quantizers > 0

given_model.to(device)
freeze(given_model)  # freeze the weights for inference
print("Given Autoencoder is ready to go!")
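An optional sanity check (not part of the original flow, and assuming freeze() works by setting requires_grad=False) to confirm the given encoder really is frozen before we build anything on top of it:

# count parameters and confirm none of them will receive gradients
n_params = sum(p.numel() for p in given_model.parameters())
n_trainable = sum(p.numel() for p in given_model.parameters() if p.requires_grad)
print(f"given_model: {n_params} params, {n_trainable} trainable (should be 0 after freeze)")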
The AA-mixer model
Let's test that out:
batch = next(train_iter)
stems, faders, train_iter = get_stems_faders(batch, train_iter, train_dl, maxstems=2)
print("len(faders) = ", len(faders))

# artificially max out these stems!
for i in range(len(faders)):
    faders[i] = 1/torch.abs(stems[i][0]).max()

playable_spectrogram(stems[0][0]*faders[0], output_type='live')  # this is just the batch
playable_spectrogram(stems[1][0]*faders[1], output_type='live')  # this is something new
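If you want to verify that the "maxing out" did what we expect (treating stems[i][0] as raw waveform tensors), the peak absolute value of each rescaled stem should come out to 1.0:

# peak of each rescaled stem should be ~1.0
for i in range(len(faders)):
    print(f"stem {i}: peak = {(stems[i][0]*faders[i]).abs().max().item():.3f}")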
Mix and apply models
aa_use_bn = False  # batch norm?
aa_use_resid = True  # use residual connections? (doesn't make much difference tbh)
emb_dims = global_args.latent_dim  # input size to aa model
hidden_dims = 64  # number of hidden dimensions in aa model. usually was 64
trivial = False  # aa_model is a no-op when this is true
debug = True
print("emb_dims = ", emb_dims)

# untrained aa model
torch.manual_seed(seed+2)
#stems, faders, val_iter = get_stems_faders(batch, val_iter, val_dl)
aa_model = AudioAlgebra(dims=emb_dims, hidden_dims=hidden_dims, use_bn=aa_use_bn, resid=aa_use_resid, trivial=trivial).to(device)
with torch.no_grad():
    zsum, zmix, archive = do_mixing(stems, faders, given_model, aa_model, device, debug=debug)

print("mix:")
playable_spectrogram(archive['mix'][0], output_type='live')
First, the effects of the given encoder \(f\)
def plot_emb_spectrograms(qs, labels, skip_ys=True):
    fig, ax = plt.subplots(3, 1, figsize=(10, 9))
    for i, (q, name) in enumerate(zip(qs, labels)):
        if i > 2 and skip_ys: break
        row, col = i % 3, i//3
        im = tokens_spectrogram_image(q, mark_batches=True)
        newsize = (np.array(im.size) * 800/im.size[0]).astype(int)
        im = im.resize(tuple(newsize))  # PIL resize returns a new image
        ax[row].imshow(im)
        ax[row].axis('off')
        ax[row].set_title(labels[i])
    plt.tight_layout()
    plt.show()
ys, ymix, ysum = archive['ys'], archive['ymix'], archive['ysum']
diff = ysum - ymix
qs = [ymix, ysum, diff, ys[0], ys[1]]
labels = ['ymix', 'ysum', 'diff := ysum - ymix', 'y0', 'y1']
print("ymix.shape = ", ymix.shape)
plot_emb_spectrograms(qs, labels)
So at least with the data I can see right now, ymix and ysum can differ by what looks to be about 50% in places.
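That "50%" is just eyeballed from the spectrogram images; one way to put a rough number on it (a quick sketch, assuming ymix and ysum are plain tensors on the same scale) is a relative-difference statistic:

# crude relative difference between summed and mixed embeddings
rel = (ysum - ymix).abs() / (ymix.abs() + 1e-8)
print(f"relative diff: mean = {rel.mean().item():.3f}, median = {rel.median().item():.3f}, max = {rel.max().item():.3f}")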
for i, (q, name) in enumerate(zip(qs, labels)):
    if i > 2: break
    print(f"{name}:")
    show_pca_point_cloud(q, mode='lines+markers')
Now the z’s (note the model is untrained at this point)
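The same plotting helpers can be reused in z-space. This is only a sketch, under the assumption that zmix and zsum (returned by do_mixing above) have the same shape as their y counterparts; since aa_model is untrained here, don't expect zsum and zmix to agree yet:

# same comparison as above, but in the aa_model's z-space
zdiff = zsum - zmix
plot_emb_spectrograms([zmix, zsum, zdiff], ['zmix', 'zsum', 'zdiff := zsum - zmix'])
show_pca_point_cloud(zmix, mode='lines+markers')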
Reconstruction / demo
Define Losses
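The actual loss definitions live in the audio-algebra library. Purely as a conceptual sketch (my assumption about the objective, not the library's exact code), the "mixing in audio should become addition in z-space" idea amounts to penalizing the distance between zmix and zsum, e.g.:

# conceptual sketch only: embedding of the mix should match the sum of stem embeddings
import torch.nn.functional as F
mix_loss = F.mse_loss(zmix, zsum)
print("example mix_loss =", mix_loss.item())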
Main run
Training loop
train_aa_model(debug=True)
if use_wandb: wandb.finish()