core

Core data structures for midi-rae: PatchState and HierarchicalPatchState

Overview

These dataclasses bundle patch embeddings with their spatial and mask metadata, replacing scattered positional return values and manual mask indexing throughout the codebase.

PatchState

Holds a set of patch embeddings at a single spatial scale, along with their grid positions and masks. Provides convenience properties for common operations like filtering visible patches.

HierarchicalPatchState

A list of PatchState objects ordered coarsest → finest (index 0 = global/CLS level). It currently holds two levels (CLS + patches) and is designed to extend to Swin-style multi-scale hierarchies later.

EncoderOutput

Full encoder output bundling the hierarchical patch states with the full (pre-MAE-masking) positions and masks needed by the decoder for reconstruction.
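The relationship between the full (pre-masking) fields and the visible patches can be sketched on toy tensors. This assumes one random MAE mask shared across the batch, which is consistent with `mae_mask` having shape `(N_full,)`. The 4x4 grid, the `mask_ratio` value, and the `randperm`-based masking are illustrative assumptions, not the encoder's actual masking code.

```python
import torch

# Toy setup (assumed): batch of 2, a 4x4 patch grid
B, H, W = 2, 4, 4
N_full = H * W
rows, cols = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
full_pos = torch.stack([rows, cols], dim=-1).reshape(-1, 2)   # (N_full, 2), row-major (row, col)
full_non_empty = torch.randint(0, 2, (B, N_full))             # (B, N_full) content mask

# Hypothetical MAE masking: keep a random subset of patches, same subset for the whole batch
mask_ratio = 0.5
n_keep = int(N_full * (1 - mask_ratio))
mae_mask = torch.zeros(N_full, dtype=torch.long)
mae_mask[torch.randperm(N_full)[:n_keep]] = 1                 # (N_full,) 1 = visible, 0 = masked

keep = mae_mask.bool()
vis_pos = full_pos[keep]                  # positions carried by the encoder's visible patches
vis_non_empty = full_non_empty[:, keep]   # (B, N_visible) content mask for visible patches
masked_pos = full_pos[~keep]              # positions only the decoder needs, to place reconstructions
```

Without `full_pos`, a decoder would know how many patches were hidden but not where they sit on the grid.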


PatchState


def PatchState(
    emb:Tensor, pos:Tensor, non_empty:Tensor, mae_mask:Tensor
)->None:

Bundle of patch embeddings at a single spatial scale with their metadata.

Attributes:
    emb: (B, N, dim) patch embeddings
    pos: (N, 2) grid coordinates (row, col) for each patch
    non_empty: (B, N) content mask; 1 where the patch has content (e.g. notes), 0 where it is empty
    mae_mask: (N,) MAE visibility mask; 1 = visible, 0 = masked out for reconstruction
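A minimal dataclass sketch consistent with the attributes above. The `dim`, `num_patches`, and `visible` names follow this page's sample usage, but the property bodies here are assumptions, not the library's source. Note that `pos` and `mae_mask` are shared across the batch, so a single boolean index filters every tensor along the patch axis.

```python
from dataclasses import dataclass

import torch
from torch import Tensor


@dataclass
class PatchState:
    """Sketch: patch embeddings at one scale plus grid positions and masks."""
    emb: Tensor        # (B, N, dim)
    pos: Tensor        # (N, 2) grid coordinates (row, col)
    non_empty: Tensor  # (B, N) 1 = patch has content
    mae_mask: Tensor   # (N,) 1 = visible, 0 = masked

    @property
    def dim(self) -> int:
        return self.emb.shape[-1]   # embedding dimension

    @property
    def num_patches(self) -> int:
        return self.emb.shape[1]    # N

    @property
    def visible(self) -> "PatchState":
        # Boolean-index the patch axis of every field with the shared MAE mask
        keep = self.mae_mask.bool()
        return PatchState(self.emb[:, keep], self.pos[keep],
                          self.non_empty[:, keep], self.mae_mask[keep])


# Usage: 5 patches, 3 of them visible
ps = PatchState(torch.randn(2, 5, 3), torch.arange(10).reshape(5, 2),
                torch.ones(2, 5), torch.tensor([1, 1, 0, 1, 0]))
vis = ps.visible
```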


HierarchicalPatchState


def HierarchicalPatchState(
    levels:list
)->None:

Multi-scale patch states, ordered coarsest → finest (currently: [0]=CLS, [1]=spatial patches).

Attributes:
    levels: List of PatchState, one per scale

Note: enc_out.patches[i] and enc_out.patches.levels[i] are equivalent
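The ordering and indexing conventions can be sketched with a small dataclass. `coarsest`, `finest`, and indexing all appear in this page's usage examples; implementing them as `levels[0]`, `levels[-1]`, and a delegating `__getitem__` is an assumption. Plain tensors stand in for `PatchState` objects.

```python
from dataclasses import dataclass

import torch


@dataclass
class HierarchicalPatchState:
    """Sketch: per-scale states ordered coarsest -> finest."""
    levels: list  # one PatchState per scale (plain tensors used as stand-ins here)

    @property
    def coarsest(self):
        return self.levels[0]    # [0] = global / CLS level

    @property
    def finest(self):
        return self.levels[-1]   # last = highest spatial resolution

    def __getitem__(self, i):
        return self.levels[i]    # h[i] and h.levels[i] are the same object

    def __len__(self):
        return len(self.levels)


# Two levels as in the current setup: one CLS "patch", then a 4x4 grid of patches
h = HierarchicalPatchState([torch.zeros(2, 1, 8), torch.zeros(2, 16, 8)])
```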


to


def to(
    device
):
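The page shows only the signature of `to`. A plausible sketch, assuming it moves every tensor in every level to `device` and returns a new hierarchy; whether the real method copies or mutates is not documented here. `Level` and `level_to` are hypothetical helpers.

```python
from dataclasses import dataclass, fields, replace

import torch


@dataclass
class Level:
    """Hypothetical stand-in for PatchState with two tensor fields."""
    emb: torch.Tensor
    pos: torch.Tensor


def level_to(level, device):
    # Move every tensor field of a dataclass instance to `device`; other fields are left as-is
    moved = {f.name: getattr(level, f.name).to(device)
             for f in fields(level)
             if isinstance(getattr(level, f.name), torch.Tensor)}
    return replace(level, **moved)


@dataclass
class HierarchicalPatchState:
    levels: list

    def to(self, device) -> "HierarchicalPatchState":
        # Assumed behavior: return a new hierarchy with every level moved to `device`
        return HierarchicalPatchState([level_to(lv, device) for lv in self.levels])


device = "cuda" if torch.cuda.is_available() else "cpu"
h = HierarchicalPatchState([Level(torch.zeros(2, 1, 8), torch.zeros(1, 2)),
                            Level(torch.zeros(2, 16, 8), torch.zeros(16, 2))]).to(device)
```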

EncoderOutput


def EncoderOutput(
    patches:HierarchicalPatchState, full_pos:Tensor, full_non_empty:Tensor, mae_mask:Tensor
)->None:

Full encoder output.

Attributes:
    patches: Encoded representations (visible patches only)
    full_pos: (N_full, 2) all grid positions before MAE masking (needed by the decoder)
    full_non_empty: (B, N_full) all content masks before MAE masking
    mae_mask: (N_full,) the MAE mask that was applied (1 = visible, 0 = masked)
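A sketch of `EncoderOutput` with a hypothetical `check_consistent` method encoding the invariants implied by the attribute shapes: `mae_mask` and `full_non_empty` span all `N_full` patches, and the finest level's positions equal `full_pos` filtered by `mae_mask`. That last check assumes visible patches keep their original grid order, which this page does not state. `SimpleNamespace` stands in for a real `HierarchicalPatchState`.

```python
from dataclasses import dataclass
from types import SimpleNamespace

import torch
from torch import Tensor


@dataclass
class EncoderOutput:
    """Sketch: visible-patch hierarchy plus the pre-masking metadata the decoder needs."""
    patches: object         # HierarchicalPatchState (visible patches only)
    full_pos: Tensor        # (N_full, 2)
    full_non_empty: Tensor  # (B, N_full)
    mae_mask: Tensor        # (N_full,) 1 = visible, 0 = masked

    def check_consistent(self) -> None:
        """Hypothetical sanity check relating the full and visible views."""
        n_full = self.full_pos.shape[0]
        keep = self.mae_mask.bool()
        if keep.shape != (n_full,):
            raise ValueError("mae_mask must cover every patch")
        if self.full_non_empty.shape[1] != n_full:
            raise ValueError("full_non_empty must cover every patch")
        # Assumes visible patches keep their original grid order
        if not torch.equal(self.full_pos[keep], self.patches.finest.pos):
            raise ValueError("finest visible positions must equal full_pos[mae_mask]")


full_pos = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
mae_mask = torch.tensor([1, 0, 0, 1])
out = EncoderOutput(
    patches=SimpleNamespace(finest=SimpleNamespace(pos=full_pos[mae_mask.bool()])),  # stand-in
    full_pos=full_pos,
    full_non_empty=torch.ones(2, 4),
    mae_mask=mae_mask,
)
out.check_consistent()  # passes: visible positions are [[0, 0], [1, 1]]
```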

Sample usage

The encoder returns an EncoderOutput containing a HierarchicalPatchState:

enc_out = encoder(img, mask_ratio=0.5)

# Access the patch hierarchy
cls_state = enc_out.patches.coarsest    # PatchState with CLS token
patch_state = enc_out.patches.finest    # PatchState with patch embeddings

Working with PatchState — filtering, shapes, masks:

ps = enc_out.patches.finest

ps.emb          # (B, N_visible, dim) — patch embeddings
ps.pos          # (N_visible, 2) — grid coordinates (row, col)
ps.non_empty    # (B, N_visible) — content mask (1=has notes)
ps.mae_mask     # (N_visible,) — all True for already-filtered patches
ps.dim          # embedding dimension
ps.num_patches  # number of patches

vis = ps.visible  # new PatchState with only MAE-visible patches

In compute_batch_loss (encoder training):

# BEFORE: 8 positional return values
# loss_dict, z1, z2, non_emptys, pos2, mae_mask2, num_tokens, recon_patches = ...

# AFTER:
loss_dict, enc_out1, enc_out2, recon_patches = compute_batch_loss(...)
non_emptys = (enc_out1.patches.finest.non_empty, enc_out2.patches.finest.non_empty)

In LightweightMAEDecoder:

# BEFORE:
# def forward(self, z, pos_full, mae_mask): ...

# AFTER:
# def forward(self, enc_out: EncoderOutput): ...
#   — gets visible embeddings, full positions, and mae_mask all from enc_out
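The AFTER signature can be fleshed out as a toy decoder. This is a hypothetical sketch, not `LightweightMAEDecoder`: it assumes the decoder reads visible embeddings from `enc_out.patches.finest`, refills masked slots with a learned mask token, adds learned row/column embeddings looked up from `full_pos`, and projects each patch to `patch_dim` values. `TinyMAEDecoder`, `row_emb`, `col_emb`, and `patch_dim` are invented for illustration.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class TinyMAEDecoder(nn.Module):
    """Hypothetical minimal decoder: refill masked slots, add positions, project per patch."""

    def __init__(self, dim: int, grid: int, patch_dim: int):
        super().__init__()
        self.mask_token = nn.Parameter(torch.zeros(dim))  # learned placeholder for masked patches
        self.row_emb = nn.Embedding(grid, dim)            # assumed learned row embeddings
        self.col_emb = nn.Embedding(grid, dim)            # assumed learned column embeddings
        self.head = nn.Linear(dim, patch_dim)             # per-patch reconstruction

    def forward(self, enc_out):
        vis_emb = enc_out.patches.finest.emb              # (B, N_visible, dim)
        keep = enc_out.mae_mask.bool()                    # (N_full,)
        B, _, dim = vis_emb.shape
        x = self.mask_token.expand(B, keep.shape[0], dim).clone()  # (B, N_full, dim)
        x[:, keep] = vis_emb                              # scatter visible embeddings into place
        pos = enc_out.full_pos                            # (N_full, 2) long (row, col)
        x = x + self.row_emb(pos[:, 0]) + self.col_emb(pos[:, 1])
        return self.head(x)                               # (B, N_full, patch_dim)


grid, dim = 4, 8
rows, cols = torch.meshgrid(torch.arange(grid), torch.arange(grid), indexing="ij")
full_pos = torch.stack([rows, cols], dim=-1).reshape(-1, 2)
mae_mask = (torch.arange(grid * grid) % 2 == 0).long()   # keep every other patch (8 of 16)
enc_out = SimpleNamespace(  # stand-in for a real EncoderOutput
    patches=SimpleNamespace(finest=SimpleNamespace(emb=torch.randn(2, 8, dim))),
    full_pos=full_pos, mae_mask=mae_mask)
decoder = TinyMAEDecoder(dim, grid, patch_dim=16)
recon = decoder(enc_out)
```

Reading everything from one `enc_out` removes the risk of passing a `mae_mask` from one view alongside positions from another.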

Future Swin hierarchy (coarsest → finest):

# levels[0] = global (like CLS), levels[1] = 4x4, levels[2] = 8x8, levels[3] = 16x16
h = enc_out.patches
h.coarsest          # global summary
h.finest            # finest-resolution patches  
h.levels[1]         # intermediate scale
h.levels[1].visible # visible patches at that scale
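One way such a hierarchy could be assembled from a 16x16 grid of patch embeddings is sketched here. It uses 2x2 average pooling as a stand-in for Swin's learned patch-merging layer (which concatenates each 2x2 neighborhood and applies a linear projection) and ignores MAE masks entirely; `grid_pos` and `build_levels` are hypothetical helpers.

```python
import torch
import torch.nn.functional as F


def grid_pos(side: int) -> torch.Tensor:
    """(side*side, 2) row-major (row, col) coordinates."""
    rows, cols = torch.meshgrid(torch.arange(side), torch.arange(side), indexing="ij")
    return torch.stack([rows, cols], dim=-1).reshape(-1, 2)


def build_levels(finest_emb: torch.Tensor, side: int, n_merges: int):
    """Return [(emb, pos), ...] ordered coarsest -> finest, with a global level at index 0."""
    B, N, dim = finest_emb.shape
    assert N == side * side, "finest level must be a full side x side grid"
    x = finest_emb.transpose(1, 2).reshape(B, dim, side, side)  # (B, dim, H, W), row-major
    levels = [(finest_emb, grid_pos(side))]
    for _ in range(n_merges):
        x = F.avg_pool2d(x, kernel_size=2)                      # halve H and W
        s = x.shape[-1]
        levels.append((x.flatten(2).transpose(1, 2), grid_pos(s)))  # (B, s*s, dim)
    # Global summary (like CLS): mean over all finest patches, placed at a dummy (0, 0) position
    levels.append((finest_emb.mean(dim=1, keepdim=True), torch.zeros(1, 2, dtype=torch.long)))
    return levels[::-1]                                         # coarsest first


# levels[0] = global, levels[1] = 4x4, levels[2] = 8x8, levels[3] = 16x16
levels = build_levels(torch.randn(2, 256, 8), side=16, n_merges=2)
```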