swin

Swin Transformer V2 Encoder for midi-rae — drop-in replacement for ViTEncoder

Design Overview

Motivation: Multi-Scale Understanding

We’re not doing mere image segmentation. The representations at different scales may be very different from one another and mapped non-linearly. Consider: the top-level representation could be something about musical genre. The next level down could be something about what part of the song we’re in. The next level down could be which part of the verse or chorus we’re in. The next level could be a given musical phrase. The next level could be the individual notes in the phrase.

Same Info, Organized Differently

ViT: The default ViT we’ve been using has 65 patches with 256 dimensions each: 65 × 256 = 16,640 encoding scalars.

For a Swin with 6 stages, embed_dim=8, and patch_h=patch_w=4, the finest grid is 128/4 = 32×32:

Level          Patch size (px)   Grid    Dim   Scalars
0 (coarsest)   128×128           1×1     256       256
1              64×64             2×2     128       512
2              32×32             4×4      64     1,024
3              16×16             8×8      32     2,048
4              8×8               16×16    16     4,096
5 (finest)     4×4               32×32     8     8,192
Total                                           16,128 encoding scalars

So the Swin carries basically the same amount of information as the ViT (16,128 vs. 16,640 scalars, about 3% less); it’s just organized differently.
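The per-level numbers follow from one rule: each stage halves the grid side and doubles the embedding dimension. A small sketch that recomputes the table from the encoder settings (swin_scalar_budget is a hypothetical helper, not part of midi-rae):

```python
def swin_scalar_budget(img_size=128, patch=4, embed_dim=8, n_stages=6):
    """Return [(grid_side, dim, scalars), ...] ordered coarsest to finest.

    Assumes each stage halves the grid side and doubles the embedding dim,
    starting from the patch-embedding grid at the finest level.
    """
    levels = []
    grid, dim = img_size // patch, embed_dim
    for _ in range(n_stages):
        levels.append((grid, dim, grid * grid * dim))
        grid, dim = grid // 2, dim * 2
    return levels[::-1]  # coarsest first, matching the table

levels = swin_scalar_budget()
for lvl, (grid, dim, n) in enumerate(levels):
    print(f"level {lvl}: {grid}x{grid} grid, dim {dim:>3}, {n:>5,} scalars")

swin_total = sum(n for _, _, n in levels)
vit_total = 65 * 256  # 65 patches x 256 dims
print(f"Swin total: {swin_total:,}  ViT total: {vit_total:,}")
```

Each finer level holds twice as many scalars as the one above it, so roughly half of the Swin total sits in the finest 32×32 level.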

What this module does

SwinEncoder is a drop-in replacement for ViTEncoder that uses the Swin Transformer V2 architecture. It takes a piano roll image (B, 1, 128, 128) and returns an EncoderOutput with hierarchical multi-scale patch states.

Why Swin V2?

  • Hierarchical representation: with the default settings, 6 levels from finest (32×32 grid, dim=8) down to a single CLS-like token (1×1, dim=256), compared to ViT’s flat single-scale output
  • Efficient attention: windowed attention with shifted windows costs O(N·M²) for N tokens and window size M, i.e. linear in N rather than O(N²)
  • V2 improvements: Cosine attention with learned log-scale temperature, continuous position bias via CPB MLP, res-post-norm for training stability
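The V2 cosine attention can be sketched in NumPy as below. This is a simplified single-head illustration, not timm’s implementation: the learned log-scale temperature is clamped at log(100), following the Swin V2 design, and the relative position bias (which the CPB MLP would produce) is simply passed in as an optional array.

```python
import numpy as np

def cosine_window_attention(q, k, v, logit_scale, rel_bias=None):
    """Single-head Swin V2-style attention over the M tokens of one window.

    q, k, v: (M, d) arrays.
    logit_scale: learned log-temperature (scalar), clamped at log(100).
    rel_bias: optional (M, M) relative position bias, e.g. from a CPB MLP.
    """
    # Cosine similarity instead of a scaled dot product
    qn = q / np.linalg.norm(q, axis=-1, keepdims=True)
    kn = k / np.linalg.norm(k, axis=-1, keepdims=True)
    scale = np.exp(min(logit_scale, np.log(100.0)))
    logits = scale * (qn @ kn.T)
    if rel_bias is not None:
        logits = logits + rel_bias
    # Numerically stable softmax over keys
    logits = logits - logits.max(axis=-1, keepdims=True)
    attn = np.exp(logits)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    return attn @ v, attn

rng = np.random.default_rng(0)
M, d = 64, 16  # an 8×8 window of 16-dim tokens
q, k, v = (rng.standard_normal((M, d)) for _ in range(3))
out, attn = cosine_window_attention(q, k, v, logit_scale=np.log(10.0))
print(out.shape, attn.shape)
```

Because cosine similarity lies in [-1, 1], the pre-softmax logits are bounded by the temperature, which is what keeps attention from collapsing onto a few tokens as the model grows.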

Architecture

With the default settings (patch_h=patch_w=4, embed_dim=8, depths=(1, 2, 2, 6, 2, 1), num_heads=(1, 1, 2, 4, 8, 16)):

Stage   Grid    Patch covers   Dim   Depth   Heads
0       32×32   4×4              8       1       1
1       16×16   8×8             16       2       1
2       8×8     16×16           32       2       2
3       4×4     32×32           64       6       4
4       2×2     64×64          128       2       8
5       1×1     128×128        256       1      16

Config is in configs/config_swin.yaml.
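To make the windowed-attention savings concrete, the sketch below counts query-key scores per transformer block at each stage of the default config. It assumes, as we understand timm’s Swin V2 blocks to do, that when a stage’s grid is no larger than window_size the window shrinks to the whole grid, so the coarse stages effectively attend globally. attention_scores_per_block is a hypothetical helper.

```python
def attention_scores_per_block(grid, window=8):
    """Return (windowed, global) query-key score counts for one block on a grid×grid map."""
    n_tokens = grid * grid
    w = min(window, grid)         # window clamped to the grid size (assumption)
    windowed = n_tokens * w * w   # each token attends only within its w×w window
    return windowed, n_tokens * n_tokens

for stage, grid in enumerate([32, 16, 8, 4, 2, 1]):
    win, full = attention_scores_per_block(grid)
    print(f"stage {stage}: {grid:>2}×{grid:<2} windowed {win:>7,}  global {full:>9,}")
```

With window_size=8, only the two finest stages actually benefit from windowing; from the 8×8 stage down, windowed and global attention coincide.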

Implementation approach

We use timm’s SwinTransformerV2Stage directly — no copied or modified Swin internals. Our SwinEncoder wrapper handles only:

  1. Patch embedding: Conv2d(in_chans, embed_dim, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w)) + LayerNorm; with the defaults this is Conv2d(1, 8, kernel_size=4, stride=4)
  2. Empty patch detection: patches where all pixels are black get a learnable empty_token
  3. MAE masking (SimMIM-style): masked patches get a learnable mask_token, and the grid stays intact so windowed attention works unmodified. Two-rate sampling: non-empty patches are masked at mae_ratio, empty patches at mae_ratio × empty_mask_ratio (empty_mask_ratio defaults to 0.05)
  4. Hierarchical output: collects each stage’s output into HierarchicalPatchState (coarsest-first), packaged as EncoderOutput
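Steps 2 and 3 can be sketched framework-free as follows. This is an illustration under assumptions, not the module’s code: non_empty_patches, two_rate_mask, and apply_mask are hypothetical helpers; a patch counts as empty when every pixel is exactly zero; and masks are drawn with an independent Bernoulli trial per patch (the real encoder may sample differently).

```python
import numpy as np

def non_empty_patches(img, ph=4, pw=4):
    """img: (H, W) piano roll -> (H//ph * W//pw,) bool, True where any pixel is nonzero."""
    H, W = img.shape
    patches = img.reshape(H // ph, ph, W // pw, pw).transpose(0, 2, 1, 3)
    return (patches != 0).any(axis=(2, 3)).reshape(-1)

def two_rate_mask(non_empty, mae_ratio, empty_mask_ratio=0.05, rng=None):
    """Return a bool mask (True = masked); empty patches are masked far less often."""
    rng = np.random.default_rng() if rng is None else rng
    rate = np.where(non_empty, mae_ratio, mae_ratio * empty_mask_ratio)
    return rng.random(non_empty.shape) < rate

def apply_mask(tokens, mask, mask_token):
    """SimMIM-style: swap masked tokens for mask_token; the (N, D) grid is kept intact."""
    return np.where(mask[:, None], mask_token[None, :], tokens)

img = np.zeros((128, 128))
img[40:44, 10:60] = 1.0  # one held note spanning several 4×4 patches
ne = non_empty_patches(img)
mask = two_rate_mask(ne, mae_ratio=0.5, rng=np.random.default_rng(0))
tokens = np.random.default_rng(1).standard_normal((ne.size, 8))
masked = apply_mask(tokens, mask, mask_token=np.zeros(8))
print(ne.sum(), mask[ne].mean(), mask[~ne].mean())
```

Masked tokens stay in place, so every window still sees a full grid; that is also why this style of masking brings no compute savings.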

Key differences from ViTEncoder

  • No CLS token (the final stage’s single 1×1 token serves as a global summary)
  • No RoPE (Swin V2 uses its own continuous position bias)
  • MAE masking keeps all tokens (SimMIM-style) — no compute savings but preserves spatial grid
  • empty_mask_ratio controls how often trivial-to-reconstruct empty patches are masked

TODOs

  • HierarchicalPatchState could store window_size per level
  • EncoderOutput could store scale metadata (downsample factors per level)

Future: advanced masking & representation learning

  • Inter-stage patch masking: dropout-style masking between encoder stages, tapered ratio per stage (mae_ratio / 2**stage_idx), using a learnable mask token. Forces robust representations at every scale.
  • Self-distillation (DINO/iBOT-style): EMA teacher provides latent targets at all scales, eliminating need for pixel-level reconstruction at coarser levels.
  • Multi-scale reconstruction losses: reconstruct downsampled piano rolls at each hierarchy level (requires bidirectional decoder, e.g. U-Net with skip connections).
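Two of these ideas reduce to short formulas. The sketch below is hypothetical (none of it exists in midi-rae yet): inter_stage_mask_ratios computes the tapered per-stage ratio mae_ratio / 2**stage_idx from the first bullet, and ema_update is the generic exponential-moving-average teacher update used in DINO/iBOT-style self-distillation, with parameters represented as a plain dict of arrays.

```python
import numpy as np

def inter_stage_mask_ratios(mae_ratio, n_stages):
    """Tapered masking: stage i masks a fraction mae_ratio / 2**i of its tokens."""
    return [mae_ratio / 2**i for i in range(n_stages)]

def ema_update(teacher, student, momentum=0.996):
    """EMA teacher: teacher <- m * teacher + (1 - m) * student, per parameter array."""
    return {name: momentum * teacher[name] + (1 - momentum) * student[name]
            for name in teacher}

ratios = inter_stage_mask_ratios(0.4, 6)
print(ratios)

student = {"w": np.ones(3)}
teacher = {"w": np.zeros(3)}
for _ in range(10):
    teacher = ema_update(teacher, student, momentum=0.9)
print(teacher["w"])  # drifts toward the student's weights
```

The teacher never receives gradients; it only tracks the student, which is what lets it supply stable latent targets at every scale.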

SwinEncoder


def SwinEncoder(
    img_height:int, # Input image height in pixels (e.g. 128)
    img_width:int, # Input image width in pixels (e.g. 128)
    patch_h:int=4, # Patch height for initial embedding
    patch_w:int=4, # Patch width for initial embedding
    in_chans:int=1, # Number of input channels (1 for piano roll)
    embed_dim:int=8, # Base embedding dimension (doubles each stage)
    depths:tuple=(1, 2, 2, 6, 2, 1), # Number of transformer blocks per stage
    num_heads:tuple=(1, 1, 2, 4, 8, 16), # Attention heads per stage
    window_size:int=8, # Window size for windowed attention
    mlp_ratio:float=4.0, # MLP hidden dim = embed_dim * mlp_ratio
    qkv_bias:bool=True, # Add bias to QKV projections
    drop_rate:float=0.0, # Dropout after patch embedding
    proj_drop_rate:float=0.0, # Dropout after attention projection
    attn_drop_rate:float=0.0, # Dropout on attention weights
    drop_path_rate:float=0.1, # Stochastic depth rate
    norm_layer:type=LayerNorm, # Normalization layer class
    mae_ratio:float=0.0, # Fraction of non-empty patches to mask (0=no masking)
    empty_mask_ratio:float=0.05, # Mask rate for empty patches relative to mae_ratio
    squash_coarse:bool=False, # Overwrite coarse-level empty patch tokens with a learned empty patch token.
):

Swin Transformer V2 Encoder for midi-rae — drop-in replacement for ViTEncoder. (Wrapper for timm routines)

# Test: verify SwinEncoder output shapes
B, C, H, W = 2, 1, 128, 128
enc = SwinEncoder(img_height=H, img_width=W)
x = torch.randn(B, C, H, W)
out = enc(x)

print(f'mae_mask:        {out.mae_mask.shape}')
print(f'full_pos:        {out.full_pos.shape}')
print(f'full_non_empty:  {out.full_non_empty.shape}')
print(f'num levels:      {len(out.patches.levels)}')
for i, ps in enumerate(out.patches.levels):
    g = int(ps.pos.shape[0]**0.5)
    p = H // g
    print(f'  level {i}: emb={ps.emb.shape}, pos={ps.pos.shape}  — grid {g}×{g} ({p}×{p} patch{"es" if ps.emb.shape[1]>1 else ""})')

# Expected hierarchy (coarsest first), 128×128 image, default 4×4 patches, B=2:
#   level 0 (coarsest): emb=(2, 1,    256), grid 1×1   (CLS-like)
#   level 1:            emb=(2, 4,    128), grid 2×2
#   level 2:            emb=(2, 16,    64), grid 4×4
#   level 3:            emb=(2, 64,    32), grid 8×8
#   level 4:            emb=(2, 256,   16), grid 16×16
#   level 5 (finest):   emb=(2, 1024,   8), grid 32×32
mae_mask:        torch.Size([2, 1024])
full_pos:        torch.Size([1024, 2])
full_non_empty:  torch.Size([2, 1024])
num levels:      6
  level 0: emb=torch.Size([2, 1, 256]), pos=torch.Size([1, 2])  — grid 1×1 (128×128 patch)
  level 1: emb=torch.Size([2, 4, 128]), pos=torch.Size([4, 2])  — grid 2×2 (64×64 patches)
  level 2: emb=torch.Size([2, 16, 64]), pos=torch.Size([16, 2])  — grid 4×4 (32×32 patches)
  level 3: emb=torch.Size([2, 64, 32]), pos=torch.Size([64, 2])  — grid 8×8 (16×16 patches)
  level 4: emb=torch.Size([2, 256, 16]), pos=torch.Size([256, 2])  — grid 16×16 (8×8 patches)
  level 5: emb=torch.Size([2, 1024, 8]), pos=torch.Size([1024, 2])  — grid 32×32 (4×4 patches)

Testing code to visualize empty-patch detection across the hierarchy: green = non-empty, red = empty.

import matplotlib.pyplot as plt
import torch.nn.functional as F
from midi_rae.data import PRPairDataset
from midi_rae.swin import SwinEncoder

# Load one image from the dataset (first channel only)
ds = PRPairDataset(split='val')
img_tensor = ds[0]['img1'][:1]
x = img_tensor.unsqueeze(0)  # (1, 1, 128, 128)

# Run empty patch detection on the finest patch grid
enc = SwinEncoder(img_height=128, img_width=128)
gh, gw = enc.grid_size  # (32, 32) with the default 4×4 patches
non_empty = enc._compute_non_empty(x)
ne = non_empty[0].reshape(1, 1, gh, gw).float()

# Build hierarchy via max-pool cascade: a coarse patch is non-empty if any child is
levels = [ne[0, 0].cpu()]  # finest grid first (32×32)
while levels[-1].shape[0] > 1:
    ne = F.max_pool2d(ne, 2)
    levels.append(ne[0, 0].cpu())

# Plot: original image + all levels (finest on the left; level 0 = coarsest, as elsewhere)
fig, axes = plt.subplots(1, len(levels) + 1, figsize=(16, 2.7))
axes[0].imshow(img_tensor[0].cpu(), cmap='gray', origin='lower', aspect='auto')
axes[0].set_title('Original\n128×128 px')
axes[0].axis('off')
for i, grid in enumerate(levels):
    g = grid.shape[0]
    p = 128 // g
    axes[i+1].imshow(grid.numpy(), cmap='RdYlGn', origin='lower', aspect='auto', vmin=0, vmax=1)
    axes[i+1].set_title(f'Level {len(levels) - 1 - i}\n{g}×{g} ({p}×{p}px)')
    axes[i+1].axis('off')
plt.tight_layout()
plt.show()
Loading 91 val files from ~/datasets/POP909_images_basic... Finished loading.

Next, we measure how often patches are empty at each hierarchy level, over up to 5,000 training images.

import torch, torch.nn.functional as F, numpy as np, matplotlib.pyplot as plt
from midi_rae.data import PRPairDataset
from midi_rae.swin import SwinEncoder

ds = PRPairDataset(split='train')
enc = SwinEncoder(img_height=128, img_width=128)
gh, gw = enc.grid_size

N = min(5000, len(ds))
nlevels = int(np.log2(gh)) + 1  # e.g. 32×32 finest grid -> 6 levels (32, 16, 8, 4, 2, 1)
print("gh = ",gh, ", nlevels = ",nlevels) 

# Collect non-empty fractions per level (index 0 = coarsest)
level_fracs = [[] for _ in range(nlevels)]


for idx in range(N):
    img = ds[idx]['img2'][:1].unsqueeze(0)  # (1,1,128,128)
    ne = enc._compute_non_empty(img)[0].reshape(1, 1, gh, gw).float()

    for lvl in range(nlevels):
        g = ne.shape[-1]
        frac = ne[0, 0].mean().item()  # fraction non-empty
        level_fracs[nlevels - 1 - lvl].append(frac)
        if g > 1: ne = F.max_pool2d(ne, 2)

# Summary table
print(f"\nStats on non-empty patches per level:")
print(f"{'Level':>6} {'Grid':>6} {'Patch':>8} {'Mean%':>7} {'Std%':>6} {'Min%':>6} {'Max%':>6}")
for lvl in range(nlevels):
    g = 2**lvl       # 1, 2, 4, 8, 16, 32
    p = 128 // g
    arr = np.array(level_fracs[lvl]) * 100
    print(f"{lvl:>6} {f'{g}x{g}':>6} {f'{p}x{p}':>8} {arr.mean():>6.1f}% {arr.std():>5.1f}% {arr.min():>5.1f}% {arr.max():>5.1f}%")

# graph inv mean freq per level for empties
from scipy.optimize import curve_fit

data = [1/(1-np.mean(level_fracs[i])) for i in range(nlevels)]
c = 1.1
fit_fn = lambda x, a, b: c + a * np.exp(b * x)
(a, b), _ = curve_fit(fit_fn, np.arange(nlevels), data)
fit = fit_fn(np.arange(nlevels), a, b)
print(f"a, b = {a}, {b}")
plt.figure(figsize=(4,3))
plt.semilogy(data, label='data')
plt.semilogy(fit, label=f'fit: {c} + {a:.2f}·exp({b:.2f}·x)')
npatches = 2**np.arange(nlevels)
print('npatches =',npatches) 
plt.semilogy(npatches[::-1], label='npatches[::-1]')
emb_dims = 8 * 2**np.arange(nlevels)[::-1]
print("emb_dims = ",emb_dims) 
plt.semilogy(np.sqrt(emb_dims)/np.sqrt(emb_dims[-1]), label=' ~ sqrt(D)')
plt.xlabel("Level (Coarsest -> Finest)")
plt.ylabel("Inverse empty frequency")
plt.legend()
plt.show()


# Histograms
fig, axes = plt.subplots(int(np.ceil(nlevels/3)), 3, figsize=(14, 6))
axes = axes.flat
for lvl in range(nlevels):
    g = 2**lvl
    p = 128 // g
    axes[lvl].hist(np.array(level_fracs[lvl]) * 100, bins=30, edgecolor='black', alpha=0.7)
    axes[lvl].set_title(f'Level {lvl}: {g}×{g} grid ( {p}x{p} px) ')
    axes[lvl].set_xlabel('% non-empty')
plt.suptitle(f'Non-empty patch fraction across {N} images', fontsize=13)
plt.tight_layout()
plt.show()
Loading 818 train files from ~/datasets/POP909_images_basic... Finished loading.
gh =  32 , nlevels =  6

Stats on non-empty patches per level:
 Level   Grid    Patch   Mean%   Std%   Min%   Max%
     0    1x1  128x128   99.9%   3.7%   0.0% 100.0%
     1    2x2    64x64   94.8%  14.2%   0.0% 100.0%
     2    4x4    32x32   49.2%   7.8%   0.0%  75.0%
     3    8x8    16x16   35.8%   6.0%   0.0%  62.5%
     4  16x16      8x8   26.3%   4.4%   0.0%  50.0%
     5  32x32      4x4   16.3%   3.5%   0.0%  43.8%
a, b = 713.1854373195413, -3.6658490988183505
npatches = [ 1  2  4  8 16 32]
emb_dims =  [256 128  64  32  16   8]


PixelShuffleHead


def PixelShuffleHead(
    out_channels:int=1, fpn_dim:int=64, hidden_ch:int=64, patch_size:int=4, grid_h:int=32, grid_w:int=32
):
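Judging by its name and arguments, PixelShuffleHead presumably maps a grid_h × grid_w feature map to an image patch_size times larger using pixel shuffle (depth-to-space). The NumPy sketch below shows only that rearrangement, laid out to match torch.nn.PixelShuffle; whatever layers the real head wraps around it are not documented here, and pixel_shuffle is a hypothetical helper.

```python
import numpy as np

def pixel_shuffle(x, r):
    """Depth-to-space: (N, C*r*r, H, W) -> (N, C, H*r, W*r), laid out like torch.nn.PixelShuffle."""
    N, Crr, H, W = x.shape
    C = Crr // (r * r)
    x = x.reshape(N, C, r, r, H, W)     # split channels into (C, row offset, col offset)
    x = x.transpose(0, 1, 4, 2, 5, 3)   # -> (N, C, H, r, W, r)
    return x.reshape(N, C, H * r, W * r)

# A 32×32 grid with 1 * 4 * 4 channels becomes a 128×128 single-channel image
feats = np.random.default_rng(0).standard_normal((2, 16, 32, 32))
img = pixel_shuffle(feats, r=4)
print(img.shape)  # (2, 1, 128, 128)
```

Channel i·r + j of grid cell (h, w) lands at pixel (h·r + i, w·r + j), so each token directly owns its patch_size × patch_size block of output pixels.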



SwinMAEDecoder


def SwinMAEDecoder(
    patch_size:int=4, dims:tuple=(256, 128, 64, 32, 16, 8), fpn_dim:int=64, depth:int=2, heads:int=4
):

FPN-style MAE decoder for SwinEncoder hierarchical output. Top-down pathway fuses all levels, reconstructs at finest scale.
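The top-down pathway can be illustrated with a generic FPN sketch. This is an assumption-laden illustration, not SwinMAEDecoder’s code: each level (coarsest first, dims (256, 128, 64, 32, 16, 8)) gets a lateral linear projection to fpn_dim, the running coarse map is upsampled 2× by nearest-neighbor repetition, and the two are summed; random matrices stand in for the learned lateral weights, and fpn_top_down is a hypothetical helper.

```python
import numpy as np

def fpn_top_down(levels, fpn_dim=64, rng=None):
    """levels: list of (g, g, D_l) arrays, coarsest first, each grid 2× the previous.
    Returns a (g_finest, g_finest, fpn_dim) fused feature map."""
    rng = np.random.default_rng(0) if rng is None else rng
    fused = None
    for feat in levels:
        W_lat = rng.standard_normal((feat.shape[-1], fpn_dim)) / np.sqrt(feat.shape[-1])
        lateral = feat @ W_lat                                 # lateral projection to fpn_dim
        if fused is not None:
            fused = fused.repeat(2, axis=0).repeat(2, axis=1)  # nearest-neighbor 2× upsample
            lateral = lateral + fused                          # top-down fusion
        fused = lateral
    return fused

dims = (256, 128, 64, 32, 16, 8)
levels = [np.ones((2**i, 2**i, d)) for i, d in enumerate(dims)]
out = fpn_top_down(levels)
print(out.shape)  # (32, 32, 64)
```

Coarse context reaches every fine position through the repeated upsampling, so the finest map that feeds the reconstruction carries information from all levels.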


SwinDecoder


def SwinDecoder(
    img_height:int=128, # Output image height
    img_width:int=128, # Output image width
    patch_h:int=4, # Patch height (must match encoder)
    patch_w:int=4, # Patch width (must match encoder)
    out_channels:int=1, # Output channels (1 for piano roll)
    embed_dim:int=8, # Base embedding dim (same as encoder)
    depths:tuple=(1, 2, 2, 6, 2, 1), # Encoder depths (finest→coarsest); reversed internally
    num_heads:tuple=(1, 1, 2, 4, 8, 16), # Encoder heads (finest→coarsest); reversed internally
    window_size:int=8, # Window size for windowed attention
    mlp_ratio:float=4.0, # MLP hidden dim = dim * mlp_ratio
    qkv_bias:bool=True, # Bias in QKV projections
    drop_path_rate:float=0.1, # Stochastic depth rate
    proj_drop_rate:float=0.0, # Dropout after attention projection
    attn_drop_rate:float=0.0, # Dropout on attention weights
    norm_layer:type=LayerNorm
):

Swin V2 Decoder for midi-rae — symmetric multi-stage decoder.

Mirrors the encoder architecture: processes coarsest→finest with Swin V2 windowed attention at every spatial scale, fusing encoder skip connections via lateral projections at each level.

Pass the same config values (embed_dim, depths, num_heads) as the encoder; they are reversed internally for the coarsest→finest decode direction.

Takes EncoderOutput directly (same interface as SwinMAEDecoder).

TODO: Try ConvTranspose2d or PixelShuffle as alternatives to linear unpatchify.
TODO: Make the unpatchify head swappable via a factory or argument.
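For reference, a linear unpatchify turns finest-level tokens back into pixels with one linear map per token followed by a rearrangement. A minimal NumPy sketch, assuming a (B, gh·gw, D) token layout in row-major grid order and a weight W of shape (D, patch_h·patch_w·out_channels); linear_unpatchify and patchify are hypothetical helpers, not the decoder’s actual head.

```python
import numpy as np

def linear_unpatchify(tokens, W, gh, gw, ph=4, pw=4, c=1):
    """tokens: (B, gh*gw, D); W: (D, ph*pw*c). Returns images (B, c, gh*ph, gw*pw)."""
    B = tokens.shape[0]
    x = tokens @ W                       # (B, gh*gw, ph*pw*c)
    x = x.reshape(B, gh, gw, ph, pw, c)
    x = x.transpose(0, 5, 1, 3, 2, 4)    # (B, c, gh, ph, gw, pw)
    return x.reshape(B, c, gh * ph, gw * pw)

def patchify(img, ph=4, pw=4):
    """Inverse layout: (B, c, H, W) -> (B, (H//ph)*(W//pw), ph*pw*c)."""
    B, c, H, W = img.shape
    x = img.reshape(B, c, H // ph, ph, W // pw, pw).transpose(0, 2, 4, 3, 5, 1)
    return x.reshape(B, (H // ph) * (W // pw), ph * pw * c)

rng = np.random.default_rng(0)
img = rng.standard_normal((2, 1, 128, 128))
# With W = identity (D = ph*pw*c = 16), unpatchify exactly inverts patchify
recon = linear_unpatchify(patchify(img), np.eye(16), gh=32, gw=32)
print(np.allclose(recon, img))  # True
```

The round trip with an identity weight is a cheap check that the rearrangement is consistent; a PixelShuffle or ConvTranspose2d head would replace only the token-to-pixel step.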


PatchExpand


def PatchExpand(
    in_dim, out_dim, norm_layer:type=LayerNorm
):

Inverse of patch merging: doubles spatial resolution via learned linear expansion. (B, H, W, C_in) → (B, 2H, 2W, C_out)
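A shape-level sketch of such an expansion, in the style of the patch-expanding layer used in Swin-Unet: a linear map from in_dim to 4·out_dim, a depth-to-space rearrangement that spreads those channels over a 2×2 neighborhood, then LayerNorm. The random weight and the exact channel ordering are assumptions; the real PatchExpand may arrange them differently, and patch_expand is a hypothetical helper.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the last dim, without learned affine parameters."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def patch_expand(x, W):
    """x: (B, H, W, C_in); W: (C_in, 4*C_out). Returns (B, 2H, 2W, C_out)."""
    B, H, Wd, _ = x.shape
    c_out = W.shape[1] // 4
    y = x @ W                              # (B, H, W, 4*C_out)
    y = y.reshape(B, H, Wd, 2, 2, c_out)   # split channels over a 2×2 neighborhood
    y = y.transpose(0, 1, 3, 2, 4, 5)      # (B, H, 2, W, 2, C_out)
    y = y.reshape(B, 2 * H, 2 * Wd, c_out)
    return layer_norm(y)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 4, 64))     # e.g. a 4×4 grid at dim 64
W = rng.standard_normal((64, 4 * 32)) / 8.0
y = patch_expand(x, W)
print(y.shape)  # (2, 8, 8, 32)
```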

Test: verify SwinDecoder output shapes:

from midi_rae.utils import * 

B, C, H, W = 3, 1, 128, 128
depths, num_heads = (1,2,2,6,2,1), (1,1,2,4,8,16)
enc = SwinEncoder(img_height=H, img_width=W, patch_h=4, patch_w=4,
                  embed_dim=8, depths=depths, num_heads=num_heads)
print(f"enc model parameters: {param_count(enc)[1]:,}")
dec = SwinDecoder(img_height=H, img_width=W, patch_h=4, patch_w=4,
                  embed_dim=8, depths=depths, num_heads=num_heads)
print(f"dec model parameters: {param_count(dec)[1]:,}")

x = torch.randn(B, C, H, W)
enc_out = enc(x)
recon = dec(enc_out)

print(f'Input:  {x.shape}')
print(f'Output: {recon.shape}')
assert recon.shape == x.shape, f'Shape mismatch: {recon.shape} != {x.shape}'
print('✓ Shapes match!')

enc model parameters: 1,748,135
dec model parameters: 1,035,735
Input:  torch.Size([3, 1, 128, 128])
Output: torch.Size([3, 1, 128, 128])
✓ Shapes match!

SwinMaskedEmbeddingPredictor


def SwinMaskedEmbeddingPredictor(
    dims:tuple=(256, 128, 64, 32, 16, 8), summary_dim:int=128, n_summaries:NoneType=None, mix_depth:int=2,
    heads:int=4, mask_ratio:float=0.4, mr_level_fac:float=1.25
):

Perceiver-style hierarchical embedding predictor (I-JEPA-inspired).

Summarizes each hierarchy level into bottleneck tokens via cross-attention, mixes across levels with self-attention, then predicts target embeddings at all positions via reverse cross-attention. Loss should be computed externally on masked positions only, e.g.:

    preds, masks = predictor(enc_out_context)
    targets = enc_out_target.patches.levels
    loss = sum(((p[~m] - t.emb[~m])**2).mean() for p, m, t in zip(preds, masks, targets))

Flow: Mask → Summarize → Mix → Predict
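The Mask → Summarize → Mix → Predict flow can be sketched with plain single-head attention. This is a simplified, hypothetical NumPy illustration, not the module’s implementation: there are no learned projections, MLPs, or norms (random matrices stand in), masked tokens are simply dropped from the context, a few summary queries per level cross-attend to the visible tokens, all summaries self-attend, and per-position queries cross-attend back to the summaries to predict every embedding. In practice the targets come from a second encoder pass; here the context doubles as the target only to exercise the shapes.

```python
import numpy as np

def attend(q, k, v):
    """Single-head scaled dot-product attention: q (n, d), k and v (m, d) -> (n, d)."""
    logits = q @ k.T / np.sqrt(q.shape[-1])
    logits = logits - logits.max(axis=-1, keepdims=True)
    w = np.exp(logits)
    return (w / w.sum(axis=-1, keepdims=True)) @ v

def predict_embeddings(levels, pos_queries, visible, d=32, n_sum=4, rng=None):
    """levels: list of (N_l, D_l) embeddings; pos_queries: list of (N_l, d);
    visible: list of (N_l,) bool, True = visible. Returns list of (N_l, D_l) predictions."""
    rng = np.random.default_rng(0) if rng is None else rng
    summaries = []
    for emb, vis in zip(levels, visible):
        if not vis.any():                              # fully masked level: empty summary
            summaries.append(np.zeros((n_sum, d)))
            continue
        P_in = rng.standard_normal((emb.shape[1], d)) / np.sqrt(emb.shape[1])
        ctx = emb[vis] @ P_in                          # Mask: only visible tokens are context
        queries = rng.standard_normal((n_sum, d))      # stand-in for learned summary queries
        summaries.append(attend(queries, ctx, ctx))    # Summarize: cross-attention
    S = np.concatenate(summaries)                      # (n_levels * n_sum, d)
    S = S + attend(S, S, S)                            # Mix: self-attention across levels
    preds = []
    for emb, pq in zip(levels, pos_queries):
        P_out = rng.standard_normal((d, emb.shape[1])) / np.sqrt(d)
        preds.append(attend(pq, S, S) @ P_out)         # Predict: reverse cross-attention
    return preds

rng = np.random.default_rng(1)
dims, grids = (256, 128, 64), (1, 2, 4)
levels = [rng.standard_normal((g * g, D)) for g, D in zip(grids, dims)]
pos_q = [rng.standard_normal((g * g, 32)) for g in grids]
visible = [rng.random(g * g) > 0.4 for g in grids]
preds = predict_embeddings(levels, pos_q, visible)

# Loss on masked positions only (visible == False), per level that has any
losses = [((p[~v] - t[~v]) ** 2).mean() for p, v, t in zip(preds, visible, levels) if (~v).any()]
print([p.shape for p in preds])
```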

Test that the predictor runs on encoder output and returns one prediction and one visibility mask per level:

from midi_rae.utils import param_count
from midi_rae.losses import safe_mean

mep = SwinMaskedEmbeddingPredictor()
print(f"mep model parameters: {param_count(mep)[1]:,}")

enc_out1 = enc_out # for main code, it'll be enc_out1 that will be the target 

x2 = torch.randn(B, C, H, W)
enc_out2 = enc(x2) # for main code, it'll be enc_out2 that will get trained on

emb_pred, masks = mep(enc_out2)  # predict from masked-input enc_out2, what the comparable embeddings are in enc_out1
print(f"len(emb_pred) = {len(emb_pred)};   (B, N, D)")

with torch.no_grad(): # just for testing/demo, no grads 
    mep_loss = 0
    for lev, (emb, mask) in enumerate(zip(emb_pred, masks)): 
        print(f"Level {lev}: emb.shape = {tuple(emb.shape)}, mask.shape = {mask.shape}, avg # visible = {mask.sum()/B:.1f}, avg # masked = {(~mask).sum()/B:.1f}")
        target = enc_out1.patches[lev].emb
        mep_loss = mep_loss + safe_mean((emb[~mask] - target[~mask])**2)  # mask=0 means masked ,  mean across each level
    mep_loss = mep_loss / enc_out1.patches.num_levels  # average over levels
print("mep_loss = ",mep_loss)
mep model parameters: 1,337,200
len(emb_pred) = 6;   (B, N, D)
Level 0: emb.shape = (3, 1, 256), mask.shape = torch.Size([3, 1]), avg # visible = 0.7, avg # masked = 0.3
Level 1: emb.shape = (3, 4, 128), mask.shape = torch.Size([3, 4]), avg # visible = 3.7, avg # masked = 0.3
Level 2: emb.shape = (3, 16, 64), mask.shape = torch.Size([3, 16]), avg # visible = 13.3, avg # masked = 2.7
Level 3: emb.shape = (3, 64, 32), mask.shape = torch.Size([3, 64]), avg # visible = 45.0, avg # masked = 19.0
Level 4: emb.shape = (3, 256, 16), mask.shape = torch.Size([3, 256]), avg # visible = 173.3, avg # masked = 82.7
Level 5: emb.shape = (3, 1024, 8), mask.shape = torch.Size([3, 1024]), avg # visible = 605.3, avg # masked = 418.7
mep_loss =  tensor(0.9994)