# ViT


<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

## ViT components

Building blocks used in both the encoder and the decoder.

------------------------------------------------------------------------

### RoPE2D

``` python

def RoPE2D(
    head_dim
):

```

*Base class for all neural network modules.*

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a
tree structure. You can assign the submodules as regular attributes:

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have
their parameters converted when you call `to`, etc.

Note: As per the example above, an `__init__()` call to the parent
class must be made before assignment on the child.

`training` (bool): whether this module is in training or evaluation mode.

``` python
# testing
import torch

head_dim = 768//4
x = torch.rand((2, 8, 256, head_dim))  # (batch, heads, tokens, head_dim)
rope = RoPE2D(head_dim)
rot_x = rope(x)
print("rot_x.shape = ",rot_x.shape)
```

    rot_x.shape =  torch.Size([2, 8, 256, 192])
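
The page does not show how `RoPE2D` computes its rotation, so the following is only a minimal sketch of one common way to build a 2D rotary embedding: the first half of each head's features is rotated by the patch row index and the second half by the column index. The names `rope_1d` and `rope_2d_sketch`, the explicit `pos` argument, and the `base` value of 10000 are assumptions made for illustration; the actual `RoPE2D` takes only `x` and may differ.

``` python
import torch

def rope_1d(x, pos, base=10000.0):
    """Rotate feature pairs of x (..., N, D) by angles set by positions pos (N,)."""
    d = x.shape[-1]
    # One frequency per feature pair: base^(-2i/d)
    freqs = base ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)   # (d/2,)
    angles = pos.to(x.dtype)[:, None] * freqs[None, :]             # (N, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]                            # even/odd features form pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_2d_sketch(x, pos):
    """x: (B, heads, N, head_dim); pos: (N, 2) holding (row, col) patch indices."""
    half = x.shape[-1] // 2
    rows = rope_1d(x[..., :half], pos[:, 0])   # first half encodes the row
    cols = rope_1d(x[..., half:], pos[:, 1])   # second half encodes the column
    return torch.cat([rows, cols], dim=-1)

# Hypothetical usage: 256 tokens laid out on a 16x16 grid
x = torch.rand(2, 8, 256, 192)
ii, jj = torch.meshgrid(torch.arange(16), torch.arange(16), indexing="ij")
pos = torch.stack([ii.flatten(), jj.flatten()], dim=-1)
rot = rope_2d_sketch(x, pos)
```

Because each feature pair is only rotated, the per-token norm is unchanged, and the token at position (0, 0) passes through untouched.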

------------------------------------------------------------------------

### Attention

``` python

def Attention(
    dim, heads, dim_qkv:NoneType=None
):

```

``` python
# testing 
x = torch.rand(2, 256, 768)
attn = Attention(768, 8) 
a = attn(x) 
print("Done: a.shape = ",a.shape)

attn2 = Attention(768, 8, 64) 
a2 = attn2(x) 
print("Done: a2.shape = ",a2.shape)
```

    Done: a.shape =  torch.Size([2, 256, 768])
    Done: a2.shape =  torch.Size([2, 256, 768])
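
The outputs above show that the result keeps the input width (768) even when `dim_qkv` is set to 64. One way to get that behaviour, sketched here under assumptions, is to treat `dim_qkv` as the per-head width of the query/key/value projections and to project back to `dim` at the end. `AttentionSketch` is a hypothetical stand-in, not the module documented here, and it omits any positional encoding.

``` python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionSketch(nn.Module):
    def __init__(self, dim, heads, dim_qkv=None):
        super().__init__()
        self.heads = heads
        # Assumption: dim_qkv is the width of each head; default splits dim evenly
        self.dim_qkv = dim_qkv if dim_qkv is not None else dim // heads
        inner = heads * self.dim_qkv
        self.to_qkv = nn.Linear(dim, 3 * inner)
        self.proj = nn.Linear(inner, dim)   # map back to the model width

    def forward(self, x):                   # x: (B, N, dim)
        B, N, _ = x.shape
        qkv = self.to_qkv(x).view(B, N, 3, self.heads, self.dim_qkv)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)            # each (B, heads, N, dim_qkv)
        out = F.scaled_dot_product_attention(q, k, v)   # requires PyTorch >= 2.0
        out = out.transpose(1, 2).reshape(B, N, self.heads * self.dim_qkv)
        return self.proj(out)

x_small = torch.rand(2, 16, 64)
a_default = AttentionSketch(64, 4)(x_small)
a_narrow = AttentionSketch(64, 4, dim_qkv=8)(x_small)
```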

------------------------------------------------------------------------

### TransformerBlock

``` python

def TransformerBlock(
    dim, heads, dim_qkv:NoneType=None, hdim:NoneType=None
):

```

``` python
# testing
x = torch.randn(2, 256, 768) 
trans = TransformerBlock(768, 8) 
out = trans(x) 
print("out.shape = ",out.shape)
```

    out.shape =  torch.Size([2, 256, 768])

## Encoder

Applies patch embedding, then a stack of transformer blocks.

------------------------------------------------------------------------

### PatchEmbedding

``` python

def PatchEmbedding(
    in_channels:int=1, # 1 for solo piano; for MIDI piano rolls, the number of instruments
    patch_size:int=16, # assuming square patches, e.g. 16 implies 16x16
    dim:int=768, # embedding dimension
):

```

``` python
# testing
pe = PatchEmbedding()
x = torch.randn(2, 1, 256, 256)
z, non_empty, pos = pe(x) 
print("z.shape, non_empty.shape, pos.shape = ",z.shape, non_empty.shape, pos.shape) 
print("pos = \n",pos.tolist())
```

    z.shape, non_empty.shape, pos.shape =  torch.Size([2, 256, 768]) torch.Size([2, 256]) torch.Size([256, 2])
    pos = 
     [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6], [0, 7], [0, 8], [0, 9], [0, 10], [0, 11], [0, 12], [0, 13], [0, 14], [0, 15], [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7], [1, 8], [1, 9], [1, 10], [1, 11], [1, 12], [1, 13], [1, 14], [1, 15], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6], [2, 7], [2, 8], [2, 9], [2, 10], [2, 11], [2, 12], [2, 13], [2, 14], [2, 15], [3, 0], [3, 1], [3, 2], [3, 3], [3, 4], [3, 5], [3, 6], [3, 7], [3, 8], [3, 9], [3, 10], [3, 11], [3, 12], [3, 13], [3, 14], [3, 15], [4, 0], [4, 1], [4, 2], [4, 3], [4, 4], [4, 5], [4, 6], [4, 7], [4, 8], [4, 9], [4, 10], [4, 11], [4, 12], [4, 13], [4, 14], [4, 15], [5, 0], [5, 1], [5, 2], [5, 3], [5, 4], [5, 5], [5, 6], [5, 7], [5, 8], [5, 9], [5, 10], [5, 11], [5, 12], [5, 13], [5, 14], [5, 15], [6, 0], [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [6, 6], [6, 7], [6, 8], [6, 9], [6, 10], [6, 11], [6, 12], [6, 13], [6, 14], [6, 15], [7, 0], [7, 1], [7, 2], [7, 3], [7, 4], [7, 5], [7, 6], [7, 7], [7, 8], [7, 9], [7, 10], [7, 11], [7, 12], [7, 13], [7, 14], [7, 15], [8, 0], [8, 1], [8, 2], [8, 3], [8, 4], [8, 5], [8, 6], [8, 7], [8, 8], [8, 9], [8, 10], [8, 11], [8, 12], [8, 13], [8, 14], [8, 15], [9, 0], [9, 1], [9, 2], [9, 3], [9, 4], [9, 5], [9, 6], [9, 7], [9, 8], [9, 9], [9, 10], [9, 11], [9, 12], [9, 13], [9, 14], [9, 15], [10, 0], [10, 1], [10, 2], [10, 3], [10, 4], [10, 5], [10, 6], [10, 7], [10, 8], [10, 9], [10, 10], [10, 11], [10, 12], [10, 13], [10, 14], [10, 15], [11, 0], [11, 1], [11, 2], [11, 3], [11, 4], [11, 5], [11, 6], [11, 7], [11, 8], [11, 9], [11, 10], [11, 11], [11, 12], [11, 13], [11, 14], [11, 15], [12, 0], [12, 1], [12, 2], [12, 3], [12, 4], [12, 5], [12, 6], [12, 7], [12, 8], [12, 9], [12, 10], [12, 11], [12, 12], [12, 13], [12, 14], [12, 15], [13, 0], [13, 1], [13, 2], [13, 3], [13, 4], [13, 5], [13, 6], [13, 7], [13, 8], [13, 9], [13, 10], [13, 11], [13, 12], [13, 13], [13, 14], [13, 15], [14, 0], [14, 1], [14, 2], [14, 3], [14, 4], [14, 5], 
     [14, 6], [14, 7], [14, 8], [14, 9], [14, 10], [14, 11], [14, 12], [14, 13], [14, 14], [14, 15], [15, 0], [15, 1], [15, 2], [15, 3], [15, 4], [15, 5], [15, 6], [15, 7], [15, 8], [15, 9], [15, 10], [15, 11], [15, 12], [15, 13], [15, 14], [15, 15]]
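
The `pos` output above is a row-major list of (row, col) patch indices, and `non_empty` has one flag per patch. As a rough sketch of how such outputs could be produced, the helper below builds the grid with `torch.meshgrid` and flags a patch as non-empty when it contains any non-zero pixel. The name `patch_grid_sketch` and the non-zero criterion are assumptions; the real `PatchEmbedding` may decide emptiness differently.

``` python
import torch
import torch.nn.functional as F

def patch_grid_sketch(x, patch_size=16):
    """x: (B, C, H, W). Returns pos (N, 2) and non_empty (B, N), N = (H/p)*(W/p)."""
    B, C, H, W = x.shape
    gh, gw = H // patch_size, W // patch_size
    ii, jj = torch.meshgrid(torch.arange(gh), torch.arange(gw), indexing="ij")
    pos = torch.stack([ii.flatten(), jj.flatten()], dim=-1)   # row-major (row, col)
    # Non-overlapping patches as columns: (B, C*p*p, N), also in row-major order
    patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)
    non_empty = patches.abs().amax(dim=1) > 0                 # any non-zero pixel in the patch
    return pos, non_empty

x_img = torch.zeros(2, 1, 256, 256)
x_img[0, 0, 20, 40] = 1.0            # single note in patch (row 1, col 2)
pos_grid, non_empty_flags = patch_grid_sketch(x_img)
```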

------------------------------------------------------------------------

### make_mae_mask

``` python

def make_mae_mask(
    non_empty, ratio:int=0, has_cls_token:bool=True
):

```

*Create the token mask for MAE training. 1=keep, 0=masked*

------------------------------------------------------------------------

### apply_mae_mask

``` python

def apply_mae_mask(
    x, pos, non_empty, mae_mask
):

```

*Apply token masking for MAE training. 1=keep, 0=masked*
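
The implementation of these two helpers is not shown here, so the following is only a minimal sketch of random MAE-style masking that follows the stated convention (1=keep, 0=masked). It assumes `ratio` is the fraction of patch tokens to mask and that a CLS token, when present, sits at index 0 and is always kept. `make_mae_mask_sketch` is a hypothetical name, and unlike the real function it uses `non_empty` only for its shape.

``` python
import torch

def make_mae_mask_sketch(non_empty, ratio=0.0, has_cls_token=True):
    """non_empty: (B, N) bool. Returns a (B, N) bool mask, True=keep, False=masked."""
    B, N = non_empty.shape
    start = 1 if has_cls_token else 0
    n_mask = int(ratio * (N - start))          # number of patch tokens to hide
    scores = torch.rand(B, N)                  # random score per token, in [0, 1)
    if has_cls_token:
        scores[:, 0] = 2.0                     # largest score, so CLS is never selected
    masked_idx = scores.argsort(dim=1)[:, :n_mask]   # lowest scores get masked
    mask = torch.ones(B, N)
    mask.scatter_(1, masked_idx, 0.0)
    return mask.bool()

flags = torch.ones(2, 65, dtype=torch.bool)    # CLS + 64 patches
m = make_mae_mask_sketch(flags, ratio=0.75)
kept_tokens = torch.randn(2, 65, 256)[0][m[0]] # boolean indexing keeps only visible tokens
```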

------------------------------------------------------------------------

### ViTEncoder

``` python

def ViTEncoder(
    in_channels, # 1 for grayscale
    image_size, # tuple (H,W), e.g. (256, 256)
    patch_size, # assuming square patches, e.g 16
    dim, # embedding dim, e.g. 768
    depth, # number of transformerblock layers -- 4?
    heads, # number of attention heads - 8?
):

```

*Vision Transformer Encoder for piano rolls, keeps track of empty
patches (non_empty) and supports masking*

``` python
B, C, H, W = 4, 1, 128, 128
patch_size, dim, depth, heads = 16, 256, 2, 4 
x = torch.randn(B,C,H,W) 
encoder = ViTEncoder( C, (H,W), patch_size, dim, depth, heads) 
enc_out = encoder(x) 
print("CLS shape:", enc_out.patches[0].emb.shape)
print("patch shape:", enc_out.patches[1].emb[0].shape)
```

    CLS shape: torch.Size([4, 1, 256])
    patch shape: torch.Size([64, 256])

## Decoder

Like the encoder, but instead of starting with `PatchEmbedding` on the
front end, it finishes with `Unpatchify` on the back end.

------------------------------------------------------------------------

### Unpatchify

``` python

def Unpatchify(
    out_channels:int=1, # 1 for solo piano; for MIDI piano rolls, the number of instruments
    image_size:tuple=(128, 128), # h,w for output image
    patch_size:int=16, # assuming square patches, e.g. 16 implies 16x16
    dim:int=768, # embedding dimension
):

```

*Take patches and assemble an image*

``` python
z = torch.randn([3, 64, dim]) 
unpatch = Unpatchify(dim=dim) # keep the defaults
img = unpatch(z) 
print("img.shape = ",img.shape)
```

    img.shape =  torch.Size([3, 1, 128, 128])
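
The example above maps 64 tokens of width `dim` to a 128x128 image, which is consistent with projecting each token to `patch_size * patch_size` pixels and tiling the patches on an 8x8 grid. The sketch below shows that idea under those assumptions; `UnpatchifySketch` and its single linear projection are hypothetical and may not match the real `Unpatchify`.

``` python
import torch
import torch.nn as nn

class UnpatchifySketch(nn.Module):
    def __init__(self, out_channels=1, image_size=(128, 128), patch_size=16, dim=768):
        super().__init__()
        self.c, self.p = out_channels, patch_size
        self.gh, self.gw = image_size[0] // patch_size, image_size[1] // patch_size
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size)

    def forward(self, z):                         # z: (B, gh*gw, dim), row-major patches
        B = z.shape[0]
        x = self.proj(z)                          # (B, gh*gw, c*p*p)
        x = x.view(B, self.gh, self.gw, self.c, self.p, self.p)
        x = x.permute(0, 3, 1, 4, 2, 5)           # (B, c, gh, p, gw, p)
        return x.reshape(B, self.c, self.gh * self.p, self.gw * self.p)

unpatch_sketch = UnpatchifySketch(dim=256)
z_tokens = torch.randn(3, 64, 256)
img_sketch = unpatch_sketch(z_tokens)
```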

------------------------------------------------------------------------

### ViTDecoder

``` python

def ViTDecoder(
    out_channels, image_size, # tuple (H,W), e.g. (256, 256)
    patch_size, # assuming square patches, e.g 16
    dim, # embedding dim, e.g. 768
    depth:int=4, # number of transformerblock layers -- 4?
    heads:int=8, # number of attention heads - 8?
):

```

*Vision Transformer Decoder for piano rolls*

``` python
# ViTDecoder.forward reads `enc_out.patches.all_emb`, so it needs the encoder's
# output object rather than a raw tensor. Passing a tensor raises
# AttributeError: 'Tensor' object has no attribute 'patches'.
decoder = ViTDecoder(out_channels=1, image_size=(H,W), patch_size=16, dim=dim)
img = decoder(enc_out)  # enc_out, H, W and dim come from the ViTEncoder example above
print(img.shape)
```

------------------------------------------------------------------------

### LightweightMAEDecoder

``` python

def LightweightMAEDecoder(
    patch_size:int=16, dim:int=256, depth:int=6, heads:int=4
):

```

*Simple decoder for MAE pretraining; reconstructs masked patches.* The
loss should compare `output[:, ~mae_mask]` against the original pixels
of the masked patches.
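
A minimal sketch of that loss, assuming `mae_mask` is a boolean vector of shape `(N,)` shared across the batch (True=keep), that `output` and the targets are laid out as `(B, N, patch_size * patch_size)`, and that mean squared error is the chosen criterion. The name `mae_loss_sketch` is hypothetical, and none of these choices are confirmed by the source.

``` python
import torch
import torch.nn.functional as F

def mae_loss_sketch(output, target_patches, mae_mask):
    """output, target_patches: (B, N, p*p); mae_mask: (N,) bool, True=keep."""
    masked = ~mae_mask                                   # patches hidden from the encoder
    return F.mse_loss(output[:, masked], target_patches[:, masked])

B, N, P = 2, 64, 256
target = torch.randn(B, N, P)
mae_mask = torch.ones(N, dtype=torch.bool)
mae_mask[:48] = False                                    # first 48 patches are masked

output = target.clone()
output[:, mae_mask] += 1.0                               # errors on kept patches are ignored
loss_kept_only = mae_loss_sketch(output, target, mae_mask)

output[:, ~mae_mask] += 1.0                              # errors on masked patches count
loss_masked = mae_loss_sketch(output, target, mae_mask)
```

Target patches in this layout could be obtained from an image with `F.unfold(img, patch_size, stride=patch_size).transpose(1, 2)`.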
