datasets

Routines for loading/handling datasets

(Many of these routines originally appeared in the audio-diffusion repo by Zach Evans, with contributions by Scott Hawley, https://github.com/zqevans/audio-diffusion/blob/main/diffusion/utils.py, but have since been modified.)

Augmentation routines

These support ‘pipelining’ in the sense of

        self.augs = eval(f'torch.nn.Sequential( {augs} )')  

(see AudioDataset below for a sample invocation), to wit: each augmentation tries to return the same datatype it was given in .forward(), be it torch.Tensor or dict.
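
For instance, a minimal sketch of how an augmentation string becomes a pipeline (the aeiou.datasets import path here is an assumption; whatever augmentations the string names must be in scope for eval):

import torch
from aeiou.datasets import Stereo, PhaseFlipper  # assumed import path

augs = 'Stereo(), PhaseFlipper()'
pipeline = eval(f'torch.nn.Sequential( {augs} )')  # string -> nn.Sequential of aug modules
out = pipeline(torch.rand(2, 1000))                # tensor in -> tensor out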

Dict pipeline usage

“Get a dict, give a dict” is the incredibly awkward but amazingly accurate summary of this policy.

Some routines may add additional info to a returned dict, if possible.

Reserved keys:

  • “inputs”: used both for the input and the output of the routine, i.e. “inputs” gets overwritten, i.e. obliterated, i.e. such operations are in-place. If you want to keep an unaltered archival copy of an input for later use, store it under a new dict key (and maybe even use .clone()), as in the sketch below.
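
A sketch of that archival-copy pattern (the key name 'archive' is arbitrary; RandomGain is defined further below):

import torch
audio = torch.rand(8)
gain_op = RandomGain(-2.0, 2.0)     # defined later on this page
x = {'inputs': audio}
x['archive'] = x['inputs'].clone()  # unaltered copy, safe from in-place 'inputs' overwrites
x = gain_op(x)                      # x['inputs'] gets overwritten; x['archive'] does not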

For the dict-enabled pipeline, a lot of the return operations will be the same, so…


source

pipeline_return

 pipeline_return (val, x, key='inputs')

little helper routine that appears at the end of most augmentations, to compress code

Type Default Details
val value to be returned (by calling function)
x original data-container that was passed in (tensor or dict)
key str inputs if x is dict, this key gets overwritten/added
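
Its behavior amounts to the following sketch (see the source link for the definitive version):

def pipeline_return(val, x, key='inputs'):
    if not isinstance(x, dict):
        return val   # tensor in, tensor out
    x[key] = val     # dict in: overwrite/add the key...
    return x         # ...and give the dict back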

source

RandomGain

 RandomGain (min_gain, max_gain)

apply a random gain to audio

Details
min_gain minimum gain to apply
max_gain maximum gain to apply
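
For reference, here is the general shape of such an augmentation, sketched under the assumption that the gain is drawn uniformly on a linear (not dB) scale, which is consistent with the test output below:

import random, torch

class RandomGainSketch(torch.nn.Module):  # illustrative stand-in, not the actual source
    def __init__(self, min_gain, max_gain):
        super().__init__()
        self.min_gain, self.max_gain = min_gain, max_gain
    def forward(self, x):
        audio = x['inputs'] if isinstance(x, dict) else x
        gain = random.uniform(self.min_gain, self.max_gain)
        return pipeline_return(gain * audio, x)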

Testing RandomGain:

audio = torch.rand(8)
print(f"audio = {audio}")
gain_op = RandomGain(-2.0,2.0)
audio2 = gain_op(audio)  # audio does not get overwritten
print(f"audio NOT OVERWRITTEN = {audio}\naudio2 = {audio2}\n")
audio = tensor([0.3984, 0.7110, 0.4340, 0.0856, 0.5396, 0.9248, 0.8500, 0.4580])
audio NOT OVERWRITTEN = tensor([0.3984, 0.7110, 0.4340, 0.0856, 0.5396, 0.9248, 0.8500, 0.4580])
audio2 = tensor([-0.2905, -0.5185, -0.3165, -0.0624, -0.3935, -0.6744, -0.6198, -0.3340])

Note how, with the dict version of the pipeline, the inputs element of the dict gets overwritten:

x = {'inputs':audio}
print(f"x['inputs'] = {x['inputs']}")
audio2 = gain_op(x)  # x['inputs'] gets overwritten but audio does not
print(f"audio NOT OVERWRITTEN = {audio}\nx['inputs'] OVERWRITTEN = {x['inputs']} ")
print(f"audio2 = {audio2}")
assert torch.equal(audio2['inputs'], x['inputs']), "Oh no.  NO idea what's going on"
print("\naudio2['inputs'] == x['inputs']: True, i.e. x['inputs'] was overwritten")
x['inputs'] = tensor([0.3984, 0.7110, 0.4340, 0.0856, 0.5396, 0.9248, 0.8500, 0.4580])
audio NOT OVERWRITTEN = tensor([0.3984, 0.7110, 0.4340, 0.0856, 0.5396, 0.9248, 0.8500, 0.4580])
x['inputs'] OVERWRITTEN = tensor([-0.7592, -1.3550, -0.8272, -0.1631, -1.0283, -1.7625, -1.6198, -0.8728]) 
audio2 = {'inputs': tensor([-0.7592, -1.3550, -0.8272, -0.1631, -1.0283, -1.7625, -1.6198, -0.8728])}

audio2['inputs'] == x['inputs']: True, i.e. x['inputs'] was overwritten

source

PadCrop

 PadCrop (n_samples, randomize=True, redraw_silence=True,
          silence_thresh=-60, max_redraws=2)

Grabs a randomly-located section from an audio file, padding with zeros in case of any misalignment

Type Default Details
n_samples length of chunk to extract from longer signal
randomize bool True draw cropped chunk from a random position in audio file
redraw_silence bool True a chunk containing silence will be replaced with a new one
silence_thresh int -60 threshold in dB below which we declare to be silence
max_redraws int 2 when redrawing silences, don’t do it more than this many times
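
The silence check behind redraw_silence presumably amounts to a peak-level test in dB, along these lines (illustrative; not the exact source):

import torch

def looks_like_silence(audio, thresh_db=-60):
    peak = audio.abs().max()
    return 20 * torch.log10(peak + 1e-8) < thresh_db  # peak below threshold => 'silence'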

Variation on PadCrop. source: Zach Evans’ audio-diffusion repo:


source

PadCrop_Normalized_T

 PadCrop_Normalized_T (n_samples:int, sample_rate:int,
                       randomize:bool=True)



source

PadCrop_Normalized_T_old

 PadCrop_Normalized_T_old (n_samples:int, randomize:bool=True)

Variation on PadCrop. source: Zach Evans’ audio-diffusion repo

Testing PadCrop()

audio = torch.rand(8)
crop_op = PadCrop(3)
torch.random.manual_seed(0)
crop1 = crop_op(audio)
print("audio = ",audio)
print("crop1 = ",crop1)                        # raw tensor version

torch.random.manual_seed(0)
crop_dict = crop_op({'inputs':audio})   # dict version
print("crop_dict = ",crop_dict)   # dict version

assert torch.equal(crop1, crop_dict['inputs']), f"These should be equal: {crop1}, {crop_dict['inputs']}"
print('crop1 == crop_dict: Success!')
audio =  tensor([0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901, 0.8964, 0.4556])
crop1 =  tensor([[0.1320, 0.3074, 0.6341]])
crop_dict =  {'inputs': tensor([[0.1320, 0.3074, 0.6341]]), 'crop_range': tensor([2, 5])}
crop1 == crop_dict: Success!

And test the dict version:

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
x = {'inputs':torch.rand(8).to(device)}
crop_op = PadCrop(3)
torch.random.manual_seed(0)
crop = crop_op(x)
print("crop =",crop)
crop = {'inputs': tensor([[0.1320, 0.3074, 0.6341]], device='cuda:0'), 'crop_range': tensor([2, 5], device='cuda:0')}

source

PhaseFlipper

 PhaseFlipper (p=0.5)

she was PHAAAAAAA-AAAASE FLIPPER, a random invert yeah

Type Default Details
p float 0.5 probability that phase flip will be applied

source

FillTheNoise

 FillTheNoise (p=0.33)

randomly adds a bit of noise, or not, just to spice things up. (Name is an homage to DJ/artist/collaborator Kill the Noise)

Type Default Details
p float 0.33 probability that noise will be added

source

RandPool

 RandPool (p=0.2)

maybe (or maybe not) do an avgpool operation, with a random-sized kernel


source

NormInputs

 NormInputs (do_norm=True)

Normalize inputs to [-1,1]. Useful for quiet inputs

Type Default Details
do_norm bool True controllable parameter for turning normalization on/off

source

Mono

 Mono (*args, **kwargs)

convert audio to mono


source

Stereo

 Stereo (*args, **kwargs)

convert audio to stereo

Masking (of inputs)

First a couple utility routines before the main masking routine:


source

smoothstep

 smoothstep (x, edge0=0.4, edge1=0.6)

an s-shaped curve, 0’s on the left side and 1’s on the right side, with zero gradient at all the 1’s and 0’s. cf. https://en.wikipedia.org/wiki/Smoothstep

Type Default Details
x a tensor of coordinates across a domain, e.g. [0,1]
edge0 float 0.4 “zero”/“left” side of smoothstep
edge1 float 0.6 “one”/“right” side of smoothstep
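
This is the standard cubic smoothstep from the Wikipedia page above; as a sketch:

import torch

def smoothstep_sketch(x, edge0=0.4, edge1=0.6):
    t = torch.clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0)
    return t * t * (3.0 - 2.0 * t)  # 0 below edge0, 1 above edge1, zero slope at both ends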

source

smoothstep_box

 smoothstep_box (coords, edges=(0.2, 0.3, 0.5, 0.6))

makes a flat region of zeros that transitions smoothly to 1’s via smoothsteps at the sides

Type Default Details
coords tensor of coordinate values
edges tuple (0.2, 0.3, 0.5, 0.6) (left 1’s boundary, left 0’s boundary, right 0’s boundary, right 1’s boundary)
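
One way to assemble such a box from two smoothsteps (a sketch; not necessarily the source's exact construction):

def smoothstep_box_sketch(coords, edges=(0.2, 0.3, 0.5, 0.6)):
    going_in  = smoothstep(coords, edges[0], edges[1])  # 0 -> 1 entering the zeros region
    going_out = smoothstep(coords, edges[2], edges[3])  # 0 -> 1 leaving it again
    return 1.0 - going_in * (1.0 - going_out)           # 1's outside, 0's inside the box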

Testing smoothstep_box:

import matplotlib.pyplot as plt
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"device = {device}")
x = torch.linspace(0,1,steps=100).to(device)
y = smoothstep_box(x)
plt.plot(x.cpu(), y.cpu())
plt.xlabel('x')
plt.show()
device = cuda

And now the main masking routine:


source

RandMask1D

 RandMask1D (mask_frac=0.25, mask_width=0.1, mask_type='simple',
             edge_width=0.2, per_channel=False, verbose=False)

Performs masking or ‘cutout’ along 1d data. Can support ‘smooth sides’ to the cutouts. Note that you probably want masking to be the last step in the augmentation pipeline

Type Default Details
mask_frac float 0.25 fraction of total input that is to be masked (helps compute no. of masked regions)
mask_width float 0.1 either a fraction of the total length (float < 1) or an exact integer value for length of each masked region
mask_type str simple ‘simple’=hard sides to cuts, ‘smoothstep’=smooth sides, ‘nyquist’=nyquist-freq wave 0.5*(1,-1,1,-1,..)
edge_width float 0.2 for mask_type=‘smoothstep’, fraction or integer value of transition regions to come in from the sides of the zeros region
per_channel bool False different masks on different channels; model can cheat if your inputs are mono
verbose bool False show logging info
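
As a worked example of how those two parameters interact (an assumption about the internals, but it matches the verbose output in the tests below):

n = 5000                             # length of the test signal below
mask_width = int(0.1 * n)            # 500 samples per masked region
n_masks = int(0.3 * n / mask_width)  # 3 regions, to mask ~30% of the input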

Let’s test the simple mask (“hard cuts”):

torch.manual_seed(0)
audio = (2*torch.rand((2,2,5000))-1)#.to(device)
mask_op = RandMask1D(mask_frac=0.3, mask_width=0.1, verbose=True)
masked = mask_op.forward(audio)

# routine to display what we got
def display_mask_data(audio, mask_op, masked):
    fig, ax = plt.subplots(1,3,figsize=(14,4))
    ax[1].plot(mask_op.mask.cpu(), label='single mask')
    for c in range(audio.shape[1]): # show different channels of masked audio
        ax[0].plot(audio[0,c,:].cpu(),  alpha=0.4, label=f'audio, channel{c}')
        ax[2].plot(masked[0,c,:].cpu(), alpha=0.4, label=f'masked audio, channel{c}')
    for i in range(3): ax[i].legend()
    plt.show() 

display_mask_data(audio, mask_op, masked)

 MMMM-  RandMask1D: Mask engaged!  self.mask_width, self.n_masks =  500 3 

Now test the mask with “smoothstep” sides and the “per-channel” masking:

mask_op = RandMask1D(mask_frac=0.4, mask_width=0.2, mask_type='smoothstep', edge_width=0.3, verbose=True, per_channel=True)
masked = mask_op.forward(audio)
display_mask_data(audio, mask_op, masked)

 MMMM-  RandMask1D: Mask engaged!  self.mask_width, self.n_masks =  1000 2 

…and let’s make sure the dict version retains the “unmasked” audio, unaltered:

masked = mask_op.forward({'inputs':audio})
display_mask_data(masked['unmasked'], mask_op, masked['inputs'] )

The idea behind the Nyquist freq replacement is that it could perhaps serve as a “mask code” that is different from (musically relevant) silence. And hopefully the neural network picks up on it while the human ear does not!
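
The replacement signal itself is just an alternating-sign wave at half amplitude, e.g.:

import torch

mask_width = 1500  # matches the verbose output below
nyquist_wave = 0.5 * torch.tensor([1.0, -1.0]).repeat(mask_width // 2)  # 0.5*(1,-1,1,-1,...)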

mask_op = RandMask1D(mask_frac=0.3, mask_width=0.3, mask_type='nyquist', verbose=True)
masked = mask_op.forward(audio)
display_mask_data(audio, mask_op, masked)

 MMMM-  RandMask1D: Mask engaged!  self.mask_width, self.n_masks =  1500 1 

AudioDataset class

The flagship class!


source

AudioDataset

 AudioDataset (paths, sample_rate=48000, sample_size=65536,
               random_crop=True, load_frac=1.0, cache_training_data=False,
               num_gpus=8, redraw_silence=True, silence_thresh=-60,
               max_redraws=2, augs='Stereo(), PhaseFlipper()',
               verbose=False, return_dict=False)

Reads from a tree of directories and serves up cropped bits from any and all audio files found therein. For efficiency, it’s best if you first “chunk” these files via chunkadelic. Modified from https://github.com/drscotthawley/audio-diffusion/blob/main/dataset/dataset.py

Type Default Details
paths list of strings of directory (/tree) names to draw audio files from
sample_rate int 48000 audio sample rate in Hz
sample_size int 65536 how many audio samples in each “chunk”
random_crop bool True take chunks from random positions within files
load_frac float 1.0 fraction of total dataset to load
cache_training_data bool False True = pre-load whole dataset into memory (not fully supported)
num_gpus int 8 used only when cache_training_data=True, to avoid duplicates
redraw_silence bool True a chunk containing silence will be replaced with a new one
silence_thresh int -60 threshold in dB below which we declare to be silence
max_redraws int 2 when redrawing silences, don’t do it more than this many times
augs str Stereo(), PhaseFlipper() list of augmentation transforms after PadCrop, as a string
verbose bool False whether to print notices of resampling or not
return_dict bool False False=return raw audio only, True=return dict of all kinds of info

Quick check to catch minor errors:

dataset = AudioDataset('examples/', augs='Stereo(), PhaseFlipper(), FillTheNoise(), NormInputs()')
signal = dataset.__getitem__(0)
print("signal.shape =",signal.shape)

print("\nStereo -------------")
dataset2 = AudioDataset('examples/', augs='Stereo(), PhaseFlipper()')
signal2 = dataset2.__getitem__(0)
print("signal2.shape =",signal2.shape)
augs = Stereo(), PhaseFlipper(), FillTheNoise(), NormInputs()
AudioDataset:2 files found.
signal.shape = torch.Size([2, 65536])

Stereo -------------
augs = Stereo(), PhaseFlipper()
AudioDataset:2 files found.
signal2.shape = torch.Size([2, 65536])

Check newer aug pipeline features: dict and masking:

print("\Dict & Masked -------------")
dataset3 = AudioDataset('examples/', augs='Stereo(), PhaseFlipper(), RandMask1D(verbose=True)',return_dict=true)
x3 = dataset3.__getitem__(0)
print("x3['inputs'].shape =",x3['inputs'].shape)
print("x3 = ",x3)
Dict & Masked -------------
augs = Stereo(), PhaseFlipper(), RandMask1D(verbose=True)
AudioDataset:2 files found.

 MMMM-  RandMask1D: Mask engaged!  self.mask_width, self.n_masks =  6553 2 

x3['inputs'].shape = torch.Size([2, 65536])
x3 =  {'filename': 'examples/stereo_pewpew.mp3', 'inputs': tensor([[-0.1429, -0.1227, -0.1061,  ..., -0.0312, -0.0317, -0.0323],
        [ 0.0950,  0.0607,  0.0258,  ..., -0.0045, -0.0045, -0.0045]]), 'crop_range': tensor([ 48459, 113995]), 'unmasked': tensor([[-0.1429, -0.1227, -0.1061,  ..., -0.0312, -0.0317, -0.0323],
        [ 0.0950,  0.0607,  0.0258,  ..., -0.0045, -0.0045, -0.0045]])}

Test how the DataLoader behaves in dict pipeline mode:

from torch.utils.data import DataLoader

For the non-dict pipeline, audio data is naturally batched already into a tensor:

dataset = AudioDataset('examples/', augs='Stereo(), PhaseFlipper()')
train_dl = DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(train_dl))
print("batch =\n",batch)
batch.shape
augs = Stereo(), PhaseFlipper()
AudioDataset:2 files found.
batch =
 tensor([[[ 0.2433,  0.2356,  0.2265,  ..., -0.0054, -0.0056, -0.0067],
         [-0.0189, -0.0315, -0.0418,  ...,  0.1011,  0.0883,  0.0695]],

        [[ 0.0003,  0.0004,  0.0006,  ..., -0.0000, -0.0000, -0.0000],
         [ 0.0003,  0.0004,  0.0006,  ..., -0.0000, -0.0000, -0.0000]]])
torch.Size([2, 2, 65536])

Whereas for the dict pipeline, we end up with a lot more information. Thankfully DataLoader already converts our batch of dicts to a dict of batches, automatically.
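
That conversion is the default collate_fn at work; a minimal demonstration (default_collate has been public since torch 1.11):

import torch
from torch.utils.data import default_collate

batch = default_collate([{'inputs': torch.zeros(2, 3)}, {'inputs': torch.ones(2, 3)}])
print(batch['inputs'].shape)  # torch.Size([2, 2, 3]): a dict of batched tensors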

dataset = AudioDataset('examples/', augs='Stereo(), PhaseFlipper()', return_dict=True)
train_dl = DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(train_dl))
print("batch =\n",batch)
batch['inputs'].shape
augs = Stereo(), PhaseFlipper()
AudioDataset:2 files found.
batch =
 {'filename': ['examples/stereo_pewpew.mp3', 'examples/example.wav'], 'inputs': tensor([[[ 0.0455,  0.0441,  0.0427,  ...,  0.0846,  0.0718,  0.0537],
         [-0.0023, -0.0020, -0.0015,  ..., -0.0283, -0.0262, -0.0253]],

        [[-0.0003, -0.0004, -0.0006,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0003, -0.0004, -0.0006,  ...,  0.0000,  0.0000,  0.0000]]]), 'crop_range': tensor([[ 35643, 101179],
        [     0,  65536]])}
torch.Size([2, 2, 65536])

WebDataset support

Background Info

Refer to the official WebDataset Repo on GitHub.

WebDataset makes it easy to write I/O pipelines for large datasets. Datasets can be stored locally or in the cloud.

They use the word “shards” but never define what “shard” means. I (S.H.) surmise they mean the groups of data files which are gathered into a series of .tar files – the .tar files are the shards?

cf. Video Tutorial: “Loading Training Data with WebDataset”.

The recommended usage for AWS S3 can be seen in this GitHub Issue comment by tmbdev (https://github.com/webdataset/webdataset/issues/21#issuecomment-706008342):

url = "pipe:s3cmd get s3://bucket/dataset-{000000..000999}.tar -"
dataset = wds.Dataset(url)...

(The s3cmd get in that quote is sic; it should read aws s3 cp.)

That URL expects a contiguously-numbered range of .tar files. So if the file numbers are contiguous (no gaps), then we’ll have an easy time. Otherwise, there are ways to pass in a long list of similar “pipe:…tar” ‘urls’, one for each and every tar file, which is still not a big deal though it may appear messier.
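
Something along these lines (illustrative; the bucket path is made up, and get_all_s3_urls below builds such lists for real):

urls = [f"pipe:aws s3 cp s3://bucket/dataset/{name} -" for name in tar_names]
dataset = wds.WebDataset(urls)  # WebDataset also accepts an explicit list of shard urls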

AWS hates double slashes, so we’ll use the following


source

fix_double_slashes

 fix_double_slashes (s, debug=False)

aws is pretty unforgiving compared to ‘normal’ filesystems. so here’s some ‘cleanup’
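
The effect is roughly this one-liner (a sketch; see the source for the real thing):

import re

def fix_double_slashes_sketch(s):
    return re.sub(r'(?<!:)/{2,}', '/', s)  # collapse repeated slashes, but keep '://'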

Test that:

s = 's3://hey///ho//lets/go'
print(fix_double_slashes(s, debug=True))
s3://hey/ho/lets/go

NOTE: be prepared for extensive ‘testing cases’ shown for the following routines.

General utility: get_s3_contents()


source

get_s3_contents

 get_s3_contents (dataset_path, s3_url_prefix='s3://s-laion-
                  audio/webdataset_tar/', filter='', recursive=True,
                  debug=False, profile='default')

Gets a list of names of files or subdirectories on an s3 path

Type Default Details
dataset_path “name” of the dataset on s3
s3_url_prefix str s3://s-laion-audio/webdataset_tar/ s3 bucket to check
filter str only grab certain filename / extensions
recursive bool True check all subdirectories. RECOMMEND LEAVING THIS TRUE
debug bool False print debugging info (don’t rely on this info staying consistent)
profile str default name of the AWS profile credentials

Let’s test that on a couple of the LAION datasets.

> Note: These tests will only work on systems on which you have valid AWS credentials for the S3 buckets in question. If the docs show a bunch of blanks in what follows, it’s because they were generated on a system without such credentials.

get_s3_contents('130000_MIDI_SONGS')[:10]
['webdataset_tar/130000_MIDI_SONGS/130000_Pop_Rock_Classical_Videogame_EDM_MIDI_Archive[6_19_15]/test/sizes.json',
 'webdataset_tar/130000_MIDI_SONGS/130000_Pop_Rock_Classical_Videogame_EDM_MIDI_Archive[6_19_15]/train/sizes.json',
 'webdataset_tar/130000_MIDI_SONGS/2/test/0.tar',
 'webdataset_tar/130000_MIDI_SONGS/2/test/sizes.json',
 'webdataset_tar/130000_MIDI_SONGS/2/train/0.tar',
 'webdataset_tar/130000_MIDI_SONGS/2/train/1.tar',
 'webdataset_tar/130000_MIDI_SONGS/2/train/2.tar',
 'webdataset_tar/130000_MIDI_SONGS/2/train/3.tar',
 'webdataset_tar/130000_MIDI_SONGS/2/train/4.tar',
 'webdataset_tar/130000_MIDI_SONGS/2/train/5.tar']
get_s3_contents('FSD50K/test/')[:10]
['webdataset_tar/FSD50K/test/0.tar',
 'webdataset_tar/FSD50K/test/1.tar',
 'webdataset_tar/FSD50K/test/10.tar',
 'webdataset_tar/FSD50K/test/11.tar',
 'webdataset_tar/FSD50K/test/12.tar',
 'webdataset_tar/FSD50K/test/13.tar',
 'webdataset_tar/FSD50K/test/14.tar',
 'webdataset_tar/FSD50K/test/15.tar',
 'webdataset_tar/FSD50K/test/16.tar',
 'webdataset_tar/FSD50K/test/17.tar']

And let’s try filtering for only tar files:

tar_names = get_s3_contents('FSD50K/test', filter='.tar')
tar_names
['webdataset_tar/FSD50K/test/0.tar',
 'webdataset_tar/FSD50K/test/1.tar',
 'webdataset_tar/FSD50K/test/10.tar',
 'webdataset_tar/FSD50K/test/11.tar',
 'webdataset_tar/FSD50K/test/12.tar',
 'webdataset_tar/FSD50K/test/13.tar',
 'webdataset_tar/FSD50K/test/14.tar',
 'webdataset_tar/FSD50K/test/15.tar',
 'webdataset_tar/FSD50K/test/16.tar',
 'webdataset_tar/FSD50K/test/17.tar',
 'webdataset_tar/FSD50K/test/18.tar',
 'webdataset_tar/FSD50K/test/19.tar',
 'webdataset_tar/FSD50K/test/2.tar',
 'webdataset_tar/FSD50K/test/3.tar',
 'webdataset_tar/FSD50K/test/4.tar',
 'webdataset_tar/FSD50K/test/5.tar',
 'webdataset_tar/FSD50K/test/6.tar',
 'webdataset_tar/FSD50K/test/7.tar',
 'webdataset_tar/FSD50K/test/8.tar',
 'webdataset_tar/FSD50K/test/9.tar']

List all LAION audio datasets:

get_s3_contents('',recursive=False)[:20]
['130000_MIDI_SONGS/',
 'Audiostock_music/',
 'BBCSoundEffects/',
 'CMU_Arctic/',
 'CREMA-D/',
 'Cambridge_mt/',
 'Clotho/',
 'CoVoST_2/',
 'ESC50_1/',
 'ESC50_2/',
 'ESC50_3/',
 'ESC50_4/',
 'ESC50_5/',
 'EmoV_DB/',
 'Europarl-st/',
 'FMA/',
 'FMA_stereo/',
 'FMA_updated/',
 'FSD50K/',
 'GTZAN/']

For contiguous file-number lists…

Maybe the range of tar numbers is contiguous. (In the LAION AudioDataset archives, they are each contiguous within the train/, valid/, and test/ subsets.) If so, let’s have something to output that range:


source

get_contiguous_range

 get_contiguous_range (tar_names)

given a list of tar file names, return a string of their numerical range if the numbers are contiguous. Otherwise return an empty string

Details
tar_names list of tar file names, although the .tar part is actually optional

cont_range = get_contiguous_range(tar_names)
cont_range
'{0..19}'

Test if leading zeros are preserved:

get_contiguous_range(['0000'+x for x in tar_names])
'{0..19}'

Test zero-element and single element versions:

print(get_contiguous_range([]))
print(get_contiguous_range([1]))

1

And show that ‘.tar’ is optional:

get_contiguous_range(['01','02','3'])
'{01..3}'

…So, if a contiguous range of tar file names is available in a WebDataset directory, then we can just use the native WebDataset creation utilities and ignore all the other %$#*& that’s about to follow below.

Let’s test the simple version first:

s3_url_prefix='s3://s-laion-audio/webdataset_tar/'
url = f"pipe:aws s3 cp {s3_url_prefix}FSD50K/test/{cont_range}.tar -"  # 'aws get' is not a thing. 'aws cp' is
print(url)
dataset = wds.WebDataset(url)
pipe:aws s3 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/{0..19}.tar -

WebDataset is a kind of IterableDataset, so we can iterate over it directly:

## NOTE TO SELF: DON'T RUN THIS ON STABILITY CLUSTER HEADNODE (But Jupyter nodes are fine)
try:
    for sample in dataset:  
        for k,v in sample.items():  # print all entries in the dict
            print(f"{k:20s} {repr(v)[:50]}")
        break                       # abort after first dict
except:
    sample = None
__key__              './mnt/audio_clip/processed_datasets/FSD50K/test/3
__url__              'pipe:aws s3 cp s3://s-laion-audio/webdataset_tar/
flac                 b'fLaC\x00\x00\x00"\x12\x00\x12\x00\x00\x0ee\x00\x
json                 b'{"text": ["The sounds of Aircraft, Engine, Fixed
if sample: 
    audio_keys = ("flac")
    found_key, rewrite_key = '', ''
    for k,v in sample.items():  
        for akey in audio_keys:
            if k.endswith(akey): 
                found_key, rewrite_key = k, akey
                break
        if '' != found_key: break 
    if '' == found_key:  # got no audio!   
        print("Error: No audio in this sample:")
        for k,v in sample.items():  # print all entries in the dict
            print(f"{k:20s} {repr(v)[:50]}")
    else:
        print("Found flac")
        flac = sample['flac']
Found flac

There’s a built-in decoder for various audio formats, so we can just use:

from aeiou.viz import audio_spectrogram_image
from IPython.display import display 
if sample:
    dataset = wds.WebDataset(url).decode(wds.torch_audio) # decode the audio via torchaudio
    sample = next(iter(dataset))
    audio, sr = (sample["flac"])
    audio = audio[:,:min(audio.shape[-1], 128000)]
    print('audio.shape = ',audio.shape)
    #(audio, specs='wave_mel', output_type='live')
    spec_graph = audio_spectrogram_image(audio, justimage=False, db=False, db_range=[-60,20])
    display(spec_graph)
audio.shape =  torch.Size([1, 128000])
/fsx/shawley/envs_sm/aa/lib/python3.10/site-packages/torchaudio/transforms/_transforms.py:611: UserWarning: Argument 'onesided' has been deprecated and has no influence on the behavior of this module.
  warnings.warn(
/fsx/shawley/envs_sm/aa/lib/python3.10/site-packages/torchaudio/functional/functional.py:576: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (513) may be set too low.
  warnings.warn(

Non-contiguously-numbered lists of tar files…

Here we’ll just get individual URLs for every tar file possible for a given list of dataset names


source

get_all_s3_urls_zach

 get_all_s3_urls_zach (names=[], subsets=[''], s3_url_prefix=None,
                       recursive=True, filter_str='tar', debug=False,
                       profiles={})

get urls of shards (tar files) for multiple datasets in one s3 bucket

Type Default Details
names list [] list of all valid LAION AudioDataset dataset names
subsets list [’’] list of subsets you want from those datasets, e.g. [‘train’,‘valid’]
s3_url_prefix NoneType None prefix for those dataset names
recursive bool True recursively list all tar files in all subdirs
filter_str str tar only grab files with this substring
debug bool False print debugging info – note: info displayed likely to change at dev’s whims
profiles dict {} dictionary of profiles for each item in names, e.g. {‘dataset1’: ‘profile1’, ‘dataset2’: ‘profile2’}

source

get_all_s3_urls

 get_all_s3_urls (names=[], subsets=[''], s3_url_prefix=None,
                  recursive=True, filter_str='tar', debug=False,
                  profiles={}, **kwargs)

get urls of shards (tar files) for multiple datasets in one s3 bucket

Type Default Details
names list [] list of all valid LAION AudioDataset dataset names, can include URLs in which case s3_url_prefix is ignored
subsets list [’’] list of subsets you want from those datasets, e.g. [‘train’,‘valid’]
s3_url_prefix NoneType None prefix for those dataset names if no s3:// supplied in names, e.g. ‘s3://s-laion-audio/webdataset_tar/’
recursive bool True recursively list all tar files in all subdirs
filter_str str tar only grab files with this substring
debug bool False print debugging info – note: info displayed likely to change at dev’s whims
profiles dict {} dictionary of S3 profiles to use, e.g. {‘s3://s-laion-audio’:‘default’}
kwargs
names = [
    #"s3://s-harmonai/datasets/songs_raw/songs_md_16bit_mono/",
    # "s3://s-harmonai/datasets/samples_raw/samples_all/1/",
    # "s3://s-harmonai/datasets/samples_raw/samples_ms/1/",
    # "s3://s-laion-audio/webdataset_tar/freesound_no_overlap", 
    "s3://s-laion-audio/webdataset_tar/FMA_stereo/"
]

print("Getting URL list...")
urls = get_all_s3_urls(
    profiles={'s3://s-harmonai':'scott'},
    names=names, 
    s3_url_prefix=None,
    recursive=True, debug=False,
)
print("len(urls) =",len(urls))
Getting URL list...
len(urls) = 838
urls[:5]
['pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FMA_stereo/test - --profile default',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FMA_stereo/test/0.tar - --profile default',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FMA_stereo/test/1.tar - --profile default',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FMA_stereo/test/10.tar - --profile default',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FMA_stereo/test/11.tar - --profile default']

source

IterableAudioDataset

 IterableAudioDataset (paths, sample_rate=48000, sample_size=65536,
                       random_crop=True, load_frac=1.0,
                       cache_training_data=False, num_gpus=8,
                       redraw_silence=True, silence_thresh=-60,
                       max_redraws=2, augs='Stereo(), PhaseFlipper()',
                       verbose=False)

Iterable version of AudioDataset, used with Chain (below)

Type Default Details
paths list of strings of directory (/tree) names to draw audio files from
sample_rate int 48000 audio sample rate in Hz
sample_size int 65536 how many audio samples in each “chunk”
random_crop bool True take chunks from random positions within files
load_frac float 1.0 fraction of total dataset to load
cache_training_data bool False True = pre-load whole dataset into memory (not fully supported)
num_gpus int 8 used only when cache_training_data=True, to avoid duplicates
redraw_silence bool True a chunk containing silence will be replaced with a new one
silence_thresh int -60 threshold in dB below which we declare to be silence
max_redraws int 2 when redrawing silences, don’t do it more than this many times
augs str Stereo(), PhaseFlipper() list of augmentation transforms after PadCrop, as a string
verbose bool False whether to print notices of resampling or not

from aeiou.viz import playable_spectrogram
iter_ds = IterableAudioDataset('/fsx/shawley/data/maestro', augs='Stereo(), NormInputs()')
assert isinstance(iter_ds, torch.utils.data.IterableDataset),"Nope"
try:
    sample = next(iter(iter_ds))
    playable_spectrogram(sample, specs='wave_mel', output_type='live')
except: pass
augs = Stereo(), NormInputs()
AudioDataset:1276 files found.
/fsx/shawley/envs_sm/aa/lib/python3.10/site-packages/torchaudio/transforms/_transforms.py:611: UserWarning: Argument 'onesided' has been deprecated and has no influence on the behavior of this module.
  warnings.warn(

AudioWebDataLoader

Uses WebDataset for audio files


source

wds_preprocess

 wds_preprocess (sample, sample_size=65536, sample_rate=48000,
                 verbose=False, random_crop=True, normalize_lufs=None,
                 metadata_prompt_funcs=None, force_channels='stereo',
                 augment_phase=True)

utility routine for AudioWebDataLoader, below. New version by Zach Evans, from https://github.com/zqevans/audio-diffusion/dataset.py. Old version in source, commented out


source

name_cache_file

 name_cache_file (url)

provides the filename to which to cache a url

AudioWebDataLoader class

Helper routines for AudioWebDataLoader (below).
source: Zach Evans’ audio-diffusion repo


source

is_valid_sample

 is_valid_sample (sample)

source: audio-diffusion repo


source

log_and_continue

 log_and_continue (exn)

Call in an exception handler to ignore any exception, issue a warning, and continue. source: audio-diffusion repo

Here’s the main class itself


source

AudioWebDataLoader

 AudioWebDataLoader (names=['FSD50K'], subsets=[''],
                     s3_url_prefix='s3://s-laion-audio/webdataset_tar/',
                     profile='',
                     audio_file_ext='wav;flac;mp3;ogg;aiff;aif',
                     filter_str='tar', recursive=True, sample_size=65536,
                     sample_rate=48000, random_crop=True, num_workers=2,
                     prefetch_factor=10, batch_size=4, shuffle_vals=[1000,
                     10000], epoch_len=1000, debug=False, verbose=False,
                     callback=<function wds_preprocess>,
                     shuffle_urls=True, shuffle_seed=None, zachs=True,
                     **kwargs)

Sets up a WebDataLoader pipeline with some typical defaults for audio files

Type Default Details
names list [‘FSD50K’] names of datasets. will search all available s3 urls
subsets list [’’] list of subsets you want from those datasets, e.g. [‘train’,‘valid’]
s3_url_prefix str s3://s-laion-audio/webdataset_tar/ prefix for those dataset names
profile str AWS S3 profile string to pass in (default: none)
audio_file_ext str wav;flac;mp3;ogg;aiff;aif extension(s) of audio files; passed to wds.to_tuple
filter_str str tar only grab files with this substring
recursive bool True recursively list all tar files in all subdirs
sample_size int 65536 how long each sample to grab via PadCrop
sample_rate int 48000 standard sr in Hz
random_crop bool True take chunks from random positions within files
num_workers int 2 number of PyTorch DataLoaders
prefetch_factor int 10 number of batches to pre-fetch
batch_size int 4 typical batch size
shuffle_vals list [1000, 10000] values passed into shuffle as per WDS tutorials
epoch_len int 1000 how many passes/loads make for an epoch? The wds part of this is not well documented, IMHO
debug bool False print info on internal workings
verbose bool False unlike debug, this only prints in the callback
callback function wds_preprocess function to call for additional user-based processing
shuffle_urls bool True shuffle url list before it’s passed to WebDataset
shuffle_seed NoneType None seed for shuffling of urls
zachs bool True use zach’s data pipeline or hawley’s
kwargs

Let’s test this dataloader:

if False:
    train_dl = AudioWebDataLoader(names=['FSD50K'], num_workers=1, debug=False, verbose=False, batch_size=2, zachs=True)
    train_iter = iter(train_dl)
    audio_batch = next(train_iter)
    audio_batch = audio_batch[0].squeeze()
    print("audio_batch.shape = ",audio_batch.shape)
    sp = playable_spectrogram(audio_batch[0], specs='melspec', output_type='live')

    sp

And go again…

if False:
    audio_batch = next(train_iter)
    audio_batch = audio_batch[0].squeeze()
    print("audio_batch.shape = ",audio_batch.shape)
    sp = playable_spectrogram(audio_batch[0], specs='melspec', output_type='live')
    sp

Simple version: get_wds_loader

A simple routine for basic pulling of audio files


source

get_wds_loader

 get_wds_loader (batch_size, sample_size, names, s3_url_prefix=None,
                 sample_rate=48000, num_workers=8, recursive=True,
                 profiles={}, epoch_steps=1000, random_crop=True,
                 normalize_lufs=None, metadata_prompt_funcs=None,
                 force_channels='stereo', augment_phase=True)

Simpler loader from https://github.com/zqevans/audio-diffusion/dataset.py

Test code for get_wds_loader:

batch_size = 4
sample_size = 2**18
sample_rate = 48000
num_workers = 4
profiles = {}

dl = get_wds_loader(
        batch_size=batch_size,
        s3_url_prefix=None,
        sample_size=sample_size,
        names=names,
        sample_rate=sample_rate,
        num_workers=num_workers,
        recursive=True,
        random_crop=True,
        epoch_steps=1,
        profiles=profiles,
)
dl_iter = iter(dl)
batch = next(dl_iter) 
print("batch = ",batch) 
print("Success!")
download failed: s3://s-laion-audio/webdataset_tar/FMA_stereo/train/642.tar to - [Errno 32] Broken pipe
download failed: s3://s-laion-audio/webdataset_tar/FMA_stereo/train/82.tar to - [Errno 32] Broken pipe
download failed: s3://s-laion-audio/webdataset_tar/FMA_stereo/test/35.tar to - [Errno 32] Broken pipe
download failed: s3://s-laion-audio/webdataset_tar/FMA_stereo/train/269.tar to - [Errno 32] Broken pipe
batch =  [tensor([[[[ 0.4060,  0.3965,  0.3820,  ...,  0.5274,  0.5349,  0.5406],
          [ 0.5739,  0.5558,  0.5435,  ...,  0.5655,  0.5707,  0.5763]],

         [[-0.3591, -0.3761, -0.3705,  ..., -0.2486, -0.3406, -0.4426],
          [-0.2027, -0.1295, -0.0775,  ..., -0.3002, -0.3320, -0.3603]],

         [[-0.1131, -0.0873, -0.1157,  ...,  0.0386,  0.0382,  0.0378],
          [-0.2008, -0.1766, -0.1978,  ...,  0.0127,  0.0131,  0.0133]],

         [[ 0.1581,  0.0486, -0.0798,  ..., -0.0239, -0.1204, -0.1762],
          [ 0.0942,  0.0589,  0.0190,  ..., -0.1309, -0.1081, -0.0622]]]]), [{'text': ['playing song in album The Cult From Moon Mountain, titled Circle Moon, by Fursaxa, of which the genre is Folk, the language code is en, the composer is Tara Burke, the date created is 2008-11-26 02:13:59'], 'original_data': {'title': ['FMA: A Dataset For Music Analysis'], 'description': ['Free Music Archive (FMA), an open and easily accessible dataset suitable for evaluating several tasks in MIR, a field concerned with browsing, searching, and organizing large music collections.'], 'license': ['MIT License'], 'filename': ['000714.mp3'], 'genre': ['Folk'], 'album': ['The Cult From Moon Mountain'], 'song_title': ['Circle Moon'], 'artist': ['Fursaxa'], 'duration': tensor([338]), 'composer': ['Tara Burke'], 'date_recorded': ['2008-11-26 02:13:59'], 'language_code': ['en']}, 'seconds_start': tensor([69]), 'seconds_total': tensor([339]), 'prompt': ['playing song in album The Cult From Moon Mountain, titled Circle Moon, by Fursaxa, of which the genre is Folk, the language code is en, the composer is Tara Burke, the date created is 2008-11-26 02:13:59']}, {'text': ["playing song titled I'm Quitting, in album Hank IV Live at WFMU on Brian's Show on 11/18/2008, by Hank IV, of which the language code is en, the date created is 2008-12-04 20:08:12, the genre is Rock"], 'original_data': {'title': ['FMA: A Dataset For Music Analysis'], 'description': ['Free Music Archive (FMA), an open and easily accessible dataset suitable for evaluating several tasks in MIR, a field concerned with browsing, searching, and organizing large music collections.'], 'license': ['MIT License'], 'filename': ['003688.mp3'], 'genre': ['Rock'], 'album': ["Hank IV Live at WFMU on Brian's Show on 11/18/2008"], 'song_title': ["I'm Quitting"], 'artist': ['Hank IV'], 'duration': tensor([146]), 'composer': ['nan'], 'date_recorded': ['2008-12-04 20:08:12'], 'language_code': ['en']}, 'seconds_start': tensor([38]), 'seconds_total': tensor([147]), 'prompt': ["playing song titled I'm Quitting, in album Hank IV Live at WFMU on Brian's Show on 11/18/2008, by Hank IV, of which the language code is en, the date created is 2008-12-04 20:08:12, the genre is Rock"]}, {'text': ['playing song titled Rother, Dinger You and Me, by Antiguo Automata Mexicano, of which the date created is 2008-12-04 19:37:15, the language code is en, the genre is Electronic'], 'original_data': {'title': ['FMA: A Dataset For Music Analysis'], 'description': ['Free Music Archive (FMA), an open and easily accessible dataset suitable for evaluating several tasks in MIR, a field concerned with browsing, searching, and organizing large music collections.'], 'license': ['MIT License'], 'filename': ['003345.mp3'], 'genre': ['Electronic'], 'album': ['nan'], 'song_title': ['Rother, Dinger You and Me'], 'artist': ['Antiguo Automata Mexicano'], 'duration': tensor([312]), 'composer': ['nan'], 'date_recorded': ['2008-12-04 19:37:15'], 'language_code': ['en']}, 'seconds_start': tensor([150]), 'seconds_total': tensor([313]), 'prompt': ['playing song titled Rother, Dinger You and Me, by Antiguo Automata Mexicano, of which the date created is 2008-12-04 19:37:15, the language code is en, the genre is Electronic']}, {'text': ['playing song titled Poctoth, in album Mono M::P Free, by Thomas Dimuzio, of which the language code is en, the genre is Electronic, the date created is 2008-11-26 03:21:17'], 'original_data': {'title': ['FMA: A Dataset For Music Analysis'], 'description': ['Free Music Archive 
(FMA), an open and easily accessible dataset suitable for evaluating several tasks in MIR, a field concerned with browsing, searching, and organizing large music collections.'], 'license': ['MIT License'], 'filename': ['002074.mp3'], 'genre': ['Electronic'], 'album': ['Mono M::P Free'], 'song_title': ['Poctoth'], 'artist': ['Thomas Dimuzio'], 'duration': tensor([126]), 'composer': ['nan'], 'date_recorded': ['2008-11-26 03:21:17'], 'language_code': ['en']}, 'seconds_start': tensor([101]), 'seconds_total': tensor([127]), 'prompt': ['playing song titled Poctoth, in album Mono M::P Free, by Thomas Dimuzio, of which the language code is en, the genre is Electronic, the date created is 2008-11-26 03:21:17']}], [[tensor([0.2057], dtype=torch.float64), tensor([0.2218], dtype=torch.float64)], [tensor([0.2613], dtype=torch.float64), tensor([0.2987], dtype=torch.float64)], [tensor([0.4829], dtype=torch.float64), tensor([0.5004], dtype=torch.float64)], [tensor([0.7989], dtype=torch.float64), tensor([0.8419], dtype=torch.float64)]]]
Success!
