Here we take Zach Mueller's CAMVID Segmentation Tutorial and try to segment our fake-cyclegan data via 'standard' classification

 
import torch, re
import socket

# Report the PyTorch / CUDA versions in package-index style, e.g. torch1.9.0 / cu111.
tv, cv = torch.__version__, torch.version.cuda
# Drop any local build tag ("+cu111", "+cpu", ...). Raw string: "\+" is an invalid escape in a plain literal.
tv = re.sub(r'\+.*', '', tv)
# Force the last character (the patch digit) to 0: "1.9.1" -> "1.9.0". Assumes a single-digit patch level.
TORCH_VERSION = 'torch' + tv[0:-1] + '0'
# torch.version.cuda is None on CPU-only builds, so report 'cpu' instead of crashing on .replace().
CUDA_VERSION = ('cu' + cv.replace('.', '')) if cv is not None else 'cpu'

print(f"TORCH_VERSION={TORCH_VERSION}; CUDA_VERSION={CUDA_VERSION}")
if torch.cuda.is_available():
    print(f"CUDA available = {torch.cuda.is_available()}, Device count = {torch.cuda.device_count()}, Current device = {torch.cuda.current_device()}")
    print(f"Device name = {torch.cuda.get_device_name()}")
else:
    # current_device()/get_device_name() need a usable CUDA device, so skip them here.
    print("CUDA available = False")
print("hostname:")
# Plain-Python equivalent of the IPython shell escape `!hostname`, so this also runs outside Jupyter.
print(socket.gethostname())
TORCH_VERSION=torch1.9.0; CUDA_VERSION=cu111
CUDA available = True, Device count = 1, Current device = 0
Device name = GeForce RTX 3080
hostname:
bengio

This article is also a Jupyter Notebook available to be run from the top down. There will be code snippets that you can then run in any environment.

Below are the versions of fastai, fastcore, wwf, and espiownage currently running at the time of writing this:

  • fastai : 2.5.2
  • fastcore : 1.3.26
  • wwf : 0.0.16
  • espiownage : 0.0.36

Libraries

from fastai.vision.all import *
from espiownage.core import *

Below you will find the exact imports for everything we use today

from fastcore.xtras import Path

from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos

from fastai.data.block import DataBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, FuncSplitter, Normalize

from fastai.layers import Mish
from fastai.losses import BaseLoss
from fastai.optimizer import ranger

from fastai.torch_core import tensor

from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
from fastai.vision.learner import unet_learner

from PIL import Image
import numpy as np

from torch import nn
from torchvision.models.resnet import resnet34

import torch
import torch.nn.functional as F

Dataset

path = Path('/home/drscotthawley/datasets/espiownage-cyclegan/')
#path = untar_data('https://anonymized.machine.com/~drscotthawley/espiownage-cyclegan.tgz')

In the original CamVid tutorial, the validation set is listed in a text file called valid.txt, one filename per line. Rather than reading such a file, we'll just use RandomSplitter for now.

Let's look at an image and see how everything lines up.

path_im = path/'images'
maskdir = 'masks_multiclass_1'
path_lbl = path/maskdir

First we need our filenames

import glob 
#fnames = get_image_files(path_im)
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
lbl_names = get_image_files(path_lbl)
len(meta_names), len(fnames), len(lbl_names)
(1200, 1200, 1200)

And now let's work with one of them

img_fn = fnames[2]
print(img_fn)
img = PILImage.create(img_fn)
img.show(figsize=(5,5))
/home/drscotthawley/datasets/espiownage-cyclegan/images/steelpan_0000002.png
<AxesSubplot:>

Now let's grab our y's. They live in the mask folder (`maskdir`), and their filenames end in `_P`.

def get_msk(o):
    "Map an image path `o` to its mask: same stem plus `_P`, same suffix, inside `path/maskdir`."
    # A named module-level function (unlike a lambda) can be pickled, e.g. by learn.export().
    return path/maskdir/f'{o.stem}_P{o.suffix}'

`o.stem` is the filename without its extension, and `o.suffix` is the extension (including the period).

Our masks are of type PILMask. We set the opacity (alpha) to 1 because we are not overlaying the mask on anything yet.

msk_name = get_msk(img_fn)
print(msk_name)
msk = PILMask.create(msk_name)
msk.show(figsize=(5,5), alpha=1)
/home/drscotthawley/datasets/espiownage-cyclegan/masks_multiclass_1/steelpan_0000002_P.png
<AxesSubplot:>

Now if we look at what our mask actually is, we can see it's a giant array of pixels:

print(tensor(msk))
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.uint8)

And just make sure that it's a simple file and not antialiasing. Let's see what values it contains:

set(np.array(msk).flatten())
{0, 1, 4, 6}

Each value represents a class. In CamVid the class names come from codes.txt; here the classes are simply the integers 0–11. Let's make a vocabulary from them:

#colors = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110]
colors = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]  # generated these with a bin size of 1 ring
#colors = list(set(np.array(msk).flatten()))
codes = [str(n) for n in range(len(colors))]; codes
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11']

The original tutorial builds its own split function: for each item, it checks whether the filename appears in the list of validation filenames read from valid.txt.

Since we're using RandomSplitter instead, we don't need that function here.

Transfer Learning between DataSets

Jeremy popularized the idea of progressive resizing:

  • Train on smaller sized images
  • Eventually get larger and larger
  • Transfer Learning loop

This first round we will train at half the image size

sz = msk.shape; sz
(384, 512)
half = tuple(int(x/2) for x in sz); half
(192, 256)
def get_annotated_images(source):
    "Return `fnames` (the images that have an annotation CSV); `source` is accepted for the `get_items` API but unused."
    return fnames

# NOTE(review): DataBlock.dataloaders has no `fnames` parameter, so `fnames=fnames` there appears to be
# swallowed by **kwargs. In that case get_image_files would list every image in the folder.
# Supplying the list through get_items makes the annotation-derived file list actually take effect.
cyclegan = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
    get_items=get_annotated_images,
    splitter=RandomSplitter(),
    get_y=get_msk,
    # train at half resolution first (progressive resizing), with ImageNet normalization for the pretrained encoder
    batch_tfms=[*aug_transforms(size=half), Normalize.from_stats(*imagenet_stats)])
dls = cyclegan.dataloaders(path/'images', bs=4)

Let's look at a batch, coloring the non-background classes (codes 1 through 11) and ignoring the background class 0:

dls.show_batch(max_n=4, vmin=1, vmax=len(codes), figsize=(14,10))

Lastly let's make our vocabulary a part of our DataLoaders, as our loss function needs to deal with the Void label

dls.vocab = codes

Now we need a methodology for grabbing that particular code from our output of numbers. Let's make everything into a dictionary

# Map each class-name string to its integer id ('0' -> 0, '1' -> 1, ...).
name2id = {code: int(code) for code in codes}
name2id
{'0': 0,
 '1': 1,
 '2': 2,
 '3': 3,
 '4': 4,
 '5': 5,
 '6': 6,
 '7': 7,
 '8': 8,
 '9': 9,
 '10': 10,
 '11': 11}

Awesome! Let's make an accuracy function

void_code = name2id['0'] # name2id['Void']

For segmentation, we squeeze the channel dimension out of the target so it becomes a matrix of class ids, one per pixel. We then compare the argmax of the predictions with the target at each pixel, ignoring void pixels, and take the average.

def acc_camvid(inp, targ, void=None):
    """Pixel accuracy that ignores void (background) pixels.

    inp:  class scores with the classes on dim 1; argmax(dim=1) gives the predicted class per pixel.
    targ: integer mask; squeeze(1) drops a size-1 channel dim if present.
    void: class id to ignore; defaults to the global `void_code`.

    If a target contains only void pixels, the accuracy is computed over those
    pixels instead, so the metric never takes the mean of an empty selection (NaN).
    """
    if void is None:
        void = void_code
    targ = targ.squeeze(1)
    mask = targ != void
    if not mask.any():  # all-void image: score the void pixels rather than averaging nothing
        mask = targ == void
    return (inp.argmax(dim=1)[mask] == targ[mask]).float().mean()

The Dynamic Unet

U-Net lets us make pixel-wise predictions by first sizing the image down and then blowing it back up into a high-resolution output. The first part is called the "encoder" and the second the "decoder".

On the image, the authors of the UNET paper describe the arrows as "denotions of different operations"

We have a special unet_learner. Something new is we can pass in some model configurations where we can declare a few things to customize it with!

  • Blur/blur final: avoid checkerboard artifacts
  • Self attention: A self-attention layer
  • y_range: Last activations go through a sigmoid for rescaling
  • Last cross - Cross-connection with the direct model input
  • Bottle - Bottleneck or not on that cross
  • Activation function
  • Norm type

Let's make a unet_learner that uses some of the new state of the art techniques. Specifically:

  • Self-attention layers: self_attention = True
  • Mish activation function: act_cls = Mish

Along with this, we will use Ranger as the optimizer.

opt = ranger
learn = unet_learner(dls, resnet34, metrics=acc_camvid, self_attention=True, act_cls=Mish, opt_func=opt)
/home/drscotthawley/envs/espi/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
learn.summary()
DynamicUnet (Input shape: 4)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     4 x 64 x 96 x 128   
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
____________________________________________________________________________
                     4 x 128 x 24 x 32   
Conv2d                                    73728      False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    8192       False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
____________________________________________________________________________
                     4 x 256 x 12 x 16   
Conv2d                                    294912     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    32768      False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
____________________________________________________________________________
                     4 x 512 x 6 x 8     
Conv2d                                    1179648    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Conv2d                                    131072     False     
BatchNorm2d                               1024       True      
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
BatchNorm2d                               1024       True      
ReLU                                                           
____________________________________________________________________________
                     4 x 1024 x 6 x 8    
Conv2d                                    4719616    True      
Mish                                                           
____________________________________________________________________________
                     4 x 512 x 6 x 8     
Conv2d                                    4719104    True      
Mish                                                           
____________________________________________________________________________
                     4 x 1024 x 6 x 8    
Conv2d                                    525312     True      
Mish                                                           
PixelShuffle                                                   
BatchNorm2d                               512        True      
Conv2d                                    2359808    True      
Mish                                                           
Conv2d                                    2359808    True      
Mish                                                           
Mish                                                           
____________________________________________________________________________
                     4 x 1024 x 12 x 16  
Conv2d                                    525312     True      
Mish                                                           
PixelShuffle                                                   
BatchNorm2d                               256        True      
Conv2d                                    1327488    True      
Mish                                                           
Conv2d                                    1327488    True      
Mish                                                           
____________________________________________________________________________
                     4 x 48 x 768        
Conv1d                                    18432      True      
Conv1d                                    18432      True      
Conv1d                                    147456     True      
Mish                                                           
____________________________________________________________________________
                     4 x 768 x 24 x 32   
Conv2d                                    295680     True      
Mish                                                           
PixelShuffle                                                   
BatchNorm2d                               128        True      
Conv2d                                    590080     True      
Mish                                                           
Conv2d                                    590080     True      
Mish                                                           
Mish                                                           
____________________________________________________________________________
                     4 x 512 x 48 x 64   
Conv2d                                    131584     True      
Mish                                                           
PixelShuffle                                                   
BatchNorm2d                               128        True      
____________________________________________________________________________
                     4 x 96 x 96 x 128   
Conv2d                                    165984     True      
Mish                                                           
Conv2d                                    83040      True      
Mish                                                           
Mish                                                           
____________________________________________________________________________
                     4 x 384 x 96 x 128  
Conv2d                                    37248      True      
Mish                                                           
PixelShuffle                                                   
ResizeToOrig                                                   
MergeLayer                                                     
Conv2d                                    88308      True      
Mish                                                           
Conv2d                                    88308      True      
Sequential                                                     
Mish                                                           
____________________________________________________________________________
                     4 x 12 x 192 x 256  
Conv2d                                    1200       True      
ToTensorBase                                                   
____________________________________________________________________________

Total params: 41,406,488
Total trainable params: 20,138,840
Total non-trainable params: 21,267,648

Optimizer used: <function ranger at 0x7f2197beb160>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback

If we do a learn.summary we can see this blow-up trend, and see that our model came in frozen. Let's find a learning rate

learn.lr_find()
/home/drscotthawley/envs/espi/lib/python3.8/site-packages/fastai/callback/schedule.py:269: UserWarning: color is redundantly defined by the 'color' keyword argument and the fmt string "ro" (-> color='r'). The keyword argument will take precedence.
  ax.plot(val, idx, 'ro', label=nm, c=color)
SuggestedLRs(valley=9.120108734350652e-05)
lr = 1e-4

With our new optimizer, we will also want to use a different fit function, called fit_flat_cos

learn.fit_flat_cos(10, slice(lr))
epoch train_loss valid_loss acc_camvid time
0 0.599080 0.475873 0.263397 00:20
1 0.501408 0.387140 0.392524 00:19
2 0.408025 0.337689 0.494808 00:20
3 0.394905 0.325769 0.500459 00:19
4 0.339167 0.290880 0.546326 00:20
5 0.356163 0.281372 0.578445 00:20
6 0.325345 0.270683 0.584719 00:20
7 0.301687 0.248154 0.647950 00:20
8 0.278787 0.237091 0.637582 00:20
9 0.283964 0.229650 0.661018 00:20
learn.save('stage-1-mc-cg')   # Zach saves in case Colab dies / gives OOM
learn.load('stage-1-mc-cg');  # he reloads as a way of skipping what came before if he restarts the runtime.
learn.show_results(max_n=5, figsize=(12,6))

Let's unfreeze the model and divide our learning rate by 4 (a rule of thumb). We'll use discriminative learning rates that go down to 1/100th of that for the earliest layers.

lrs = slice(lr/400, lr/4)
lr, lrs
(0.0001, slice(2.5e-07, 2.5e-05, None))
learn.unfreeze()

And train for a bit more

learn.fit_flat_cos(12, lrs)
epoch train_loss valid_loss acc_camvid time
0 0.260153 0.233376 0.641235 00:21
1 0.257960 0.225589 0.677151 00:21
2 0.250809 0.225691 0.659791 00:21
3 0.258414 0.218329 0.681218 00:21
4 0.252307 0.225296 0.698392 00:21
5 0.255101 0.216371 0.686969 00:21
6 0.256417 0.215938 0.671820 00:21
7 0.244713 0.215095 0.691165 00:21
8 0.230717 0.206591 0.696377 00:21
9 0.237418 0.214904 0.724156 00:21
10 0.227629 0.203007 0.713009 00:21
11 0.224347 0.204781 0.714381 00:21

Now let's save that model away

learn.save('seg_1_mc_cg')
Path('models/seg_1_mc_cg.pth')
learn.load('seg_1_mc_cg')
<fastai.learner.Learner at 0x7f217c7b8670>

And look at a few results

learn.show_results(max_n=6, figsize=(18,8))

Inference

Let's take a look at how to do inference with test_dl

dl = learn.dls.test_dl(fnames[:5])
dl.show_batch()

Let's do the first five pictures

preds = learn.get_preds(dl=dl)
preds[0].shape
torch.Size([5, 12, 192, 256])

Alright, so we have a 5x12x192x256 tensor: 5 images, 12 class scores per pixel, at half size (192x256).

len(codes)
12

What does this mean? We had five images, so each one is one of our five images in our batch. Let's look at the first

pred_1 = preds[0][0]
pred_1.shape
torch.Size([12, 192, 256])

Now let's take the argmax of our values

pred_arx = pred_1.argmax(dim=0)

And look at it

plt.imshow(pred_arx)
<matplotlib.image.AxesImage at 0x7f21ed7b19d0>

What do we do from here? We need to save it. We can do this in one of two ways: as an image (via a NumPy array), or as a raw tensor (e.g., to use later).

pred_arx = pred_arx.numpy()
# Stretch the class ids to the full 0..255 range for viewing. Guard the constant-map case
# (e.g. an all-background prediction), which would otherwise divide by zero and yield NaNs.
lo, hi = pred_arx.min(), pred_arx.max()
scale = 255.0 / (hi - lo) if hi > lo else 0.0
rescaled = ((pred_arx - lo) * scale).astype(np.uint8)
im = Image.fromarray(rescaled)
im
im.save('test.png')

Now let's do the same for each of our predictions:

for i, pred in enumerate(preds[0]):
    # pred is (n_classes, H, W); argmax over dim 0 gives the predicted class per pixel
    pred_arg = pred.argmax(dim=0).numpy()
    # Min-max stretch to 0..255, guarded against constant (e.g. all-background) predictions.
    # Without the guard these divide by zero and produce garbage pixels.
    lo, hi = pred_arg.min(), pred_arg.max()
    scale = 255.0 / (hi - lo) if hi > lo else 0.0
    rescaled = ((pred_arg - lo) * scale).astype(np.uint8)
    Image.fromarray(rescaled).save(f'Image_{i}.png')
/tmp/ipykernel_291448/2136161077.py:3: RuntimeWarning: divide by zero encountered in true_divide
  rescaled = (255.0 / pred_arg.max() * (pred_arg - pred_arg.min())).astype(np.uint8)
/tmp/ipykernel_291448/2136161077.py:3: RuntimeWarning: invalid value encountered in multiply
  rescaled = (255.0 / pred_arg.max() * (pred_arg - pred_arg.min())).astype(np.uint8)

Now let's save away the raw:

torch.save(preds[0][0], 'Image_1.pt')
pred_1 = torch.load('Image_1.pt')
plt.imshow(pred_1.argmax(dim=0))
<matplotlib.image.AxesImage at 0x7f21ed7251c0>

Full Size (Homework)

Now let's go full-sized. Restart your kernel to free up memory.

from fastai.vision.all import *
from espiownage.core import *
import glob
path = Path('/home/drscotthawley/datasets/espiownage-cyclegan/')
#path = untar_data('https://anonymized.machine.com/~drscotthawley/espiownage-cyclegan.tgz')
path_im = path/'images'

# The kernel was restarted, so everything from above must be redefined here -- including maskdir,
# which path_lbl and get_msk depend on (otherwise: NameError).
maskdir = 'masks_multiclass_1'
path_lbl = path/maskdir
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
lbl_names = get_image_files(path_lbl)

get_msk = lambda o: path/maskdir/f'{o.stem}_P{o.suffix}'
colors = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
codes = [str(n) for n in colors]; codes
sz = (384, 512)  # full image size (the mask shape found earlier)
name2id = {code: int(code) for code in codes}
void_code = name2id['0']

def acc_camvid(inp, targ, void=None):  # redefined after the kernel restart
    """Pixel accuracy that ignores void (background) pixels.

    inp:  class scores with the classes on dim 1; argmax(dim=1) gives the predicted class per pixel.
    targ: integer mask; squeeze(1) drops a size-1 channel dim if present.
    void: class id to ignore; defaults to the global `void_code`.

    If a target contains only void pixels, the accuracy is computed over those
    pixels instead, so the metric never takes the mean of an empty selection (NaN).
    """
    if void is None:
        void = void_code
    targ = targ.squeeze(1)
    mask = targ != void
    if not mask.any():  # all-void image: score the void pixels rather than averaging nothing
        mask = targ == void
    return (inp.argmax(dim=1)[mask] == targ[mask]).float().mean()

And re-make our dataloaders. But this time we want our size to be the full size

def get_annotated_images(source):
    "Return `fnames` (the images that have an annotation CSV); `source` is accepted for the `get_items` API but unused."
    return fnames

# NOTE(review): a `fnames=` kwarg passed to .dataloaders() appears to be ignored, so the file list
# is supplied through get_items instead of listing every image in the folder.
cyclegan = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_annotated_images,
                   splitter=RandomSplitter(),
                   get_y=get_msk,
                   batch_tfms=[*aug_transforms(size=sz), Normalize.from_stats(*imagenet_stats)])

We'll also want to lower our batch size to not run out of memory

dls = cyclegan.dataloaders(path/"images", fnames=fnames, bs=2)

Let's assign our vocab, make our learner, and load our weights

opt = ranger
def acc_camvid2(inp, targ, void=None):
    """Pixel accuracy over non-void pixels, falling back to the void pixels for empty images.

    inp:  class scores with the classes on dim 1; argmax(dim=1) gives the predicted class per pixel.
    targ: integer mask; squeeze(1) drops a size-1 channel dim if present.
    void: class id to ignore; defaults to the global `void_code`.
    """
    if void is None:
        void = void_code
    targ = targ.squeeze(1)
    mask = targ != void
    # Empty image (all void): score the void pixels instead of taking the mean of nothing (NaN).
    # mask.any() avoids materializing targ[mask] just to test for emptiness.
    if not mask.any():
        mask = targ == void
    return (inp.argmax(dim=1)[mask] == targ[mask]).float().mean()

dls.vocab = codes
learn = unet_learner(dls, resnet34, metrics=acc_camvid2, self_attention=True, act_cls=Mish, opt_func=opt)
learn.load('seg_1_mc_cg');

And now let's find our learning rate and train!

learn.lr_find()
SuggestedLRs(valley=4.365158383734524e-05)
lr = 4e-5
learn.fit_flat_cos(10, slice(lr))
epoch train_loss valid_loss acc_camvid2 time
0 0.348980 0.283500 0.487253 01:03
1 0.286115 0.288901 0.508010 01:03
2 0.279368 0.295443 0.475468 01:03
3 0.267749 0.274991 0.579448 01:03
4 0.278269 0.269600 0.520542 01:03
5 0.263586 0.305607 0.467321 01:03
6 0.250486 0.258819 0.533480 01:03
7 0.248353 0.266652 0.508037 01:03
8 0.224751 0.236934 0.575392 01:03
9 0.227134 0.247074 0.573274 01:03
learn.save('seg_1_mg_cg')
Path('models/seg_1_mg_cg.pth')
learn.unfreeze()
lrs = slice(1e-6,lr/10); lrs
slice(1e-06, 4.000000000000001e-06, None)
learn.fit_flat_cos(10, lrs)
epoch train_loss valid_loss acc_camvid2 time
0 0.216542 0.242073 0.561076 01:07
1 0.219427 0.244770 0.570780 01:07
2 0.198913 0.246989 0.567264 01:07
3 0.212870 0.237476 0.592173 01:08
4 0.204229 0.247322 0.574117 01:07
5 0.219369 0.231808 0.613849 01:07
6 0.187700 0.236111 0.596212 01:07
7 0.182536 0.223931 0.607258 01:07
8 0.186894 0.235464 0.588114 01:07
9 0.199770 0.237034 0.592395 01:07
learn.save('seg_2_mg_cg')
Path('models/seg_2_mg_cg.pth')
learn.show_results(max_n=4, figsize=(18,8))
interp = SegmentationInterpretation.from_learner(learn)
interp.plot_top_losses(k=5)

Weighted Loss Functions (not working)

We can use weighted loss functions to help with class imbalancing. We need to do this because simply oversampling won't quite work here! So, how do we do it? fastai's CrossEntropyLossFlat is just a wrapper around PyTorch's CrossEntropyLoss, so we can pass in a weight parameter (even if it doesn't show up in our autocompletion!)

class CrossEntropyLossFlat(BaseLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    # Tells fastai the targets are integer class indices (not float/one-hot labels).
    y_int = True

    def __init__(self, *loss_args, axis=-1, **loss_kwargs):
        # BaseLoss wraps nn.CrossEntropyLoss; any extra positional/keyword
        # arguments (e.g. `weight=`) are forwarded to that PyTorch loss.
        super().__init__(nn.CrossEntropyLoss, *loss_args, axis=axis, **loss_kwargs)

    def decodes(self, x):
        "Convert raw scores into predicted class indices along `self.axis`."
        return x.argmax(dim=self.axis)

    def activation(self, x):
        "Convert raw scores into class probabilities along `self.axis`."
        return F.softmax(x, dim=self.axis)

But what should this weight be? It needs to be a 1-D tensor of length n (shape `(n,)`), where n is the number of classes in your dataset (here, `len(codes)`). We'll use a quick example where every class except the last has a weight of 0.9 (90%) and the last class has a weight of 1.1 (110%).

Also, since we are training on the GPU, the weight tensor needs to be on the GPU as well:

# nn.CrossEntropyLoss expects `weight` to be a 1-D tensor with exactly one
# entry per class, i.e. shape (C,). A nested list would build a (1, C) tensor,
# which the loss rejects, and C must equal the number of classes the model
# predicts for THIS dataset (the learner's vocab is `codes`), not CAMVID's 32.
# Every class gets weight 0.9 except the last, which gets 1.1.
n_classes = len(codes)
# The loss runs on the GPU, so the weight tensor must live there too.
weights = torch.tensor([0.9] * (n_classes - 1) + [1.1]).cuda()
weights
tensor([[0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000,
         0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000,
         0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000,
         0.9000, 0.9000, 0.9000, 0.9000, 1.1000]], device='cuda:0')

Now we can pass this into CrossEntropyLossFlat

  • Note: since this is segmentation, the class scores live in dimension 1 of the model output, so we need to set axis=1
# Swap a class-weighted cross-entropy into the existing learner; axis=1 is
# presumably the class dimension of a (bs, n_classes, H, W) U-Net output — confirm.
learn.loss_func = CrossEntropyLossFlat(weight=weights, axis=1)

(or pass it in when creating the learner with unet_learner)

# Alternative: hand the weighted loss to the learner at construction time.
# axis=1 because the class scores presumably sit in dim 1 of the
# (bs, n_classes, H, W) U-Net output — confirm.
loss_func = CrossEntropyLossFlat(weight=weights, axis=1)
# Use acc_camvid2, the same metric as the learner trained above, so the
# numbers are comparable; it also falls back gracefully on images whose
# pixels are all void_code.
learn = unet_learner(dls, resnet34, metrics=acc_camvid2, loss_func=loss_func)
learn.lr_find()
0.00% [0/1 00:00<00:00]
0.00% [0/480 00:00<00:00]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_291448/2434377672.py in <module>
----> 1 learn.lr_find()

~/envs/espi/lib/python3.8/site-packages/fastai/callback/schedule.py in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggest_funcs)
    280     n_epoch = num_it//len(self.dls.train) + 1
    281     cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 282     with self.no_logging(): self.fit(n_epoch, cbs=cb)
    283     if suggest_funcs is not None:
    284         lrs, losses = tensor(self.recorder.lrs[num_it//10:-5]), tensor(self.recorder.losses[num_it//10:-5])

~/envs/espi/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    219             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    220             self.n_epoch = n_epoch
--> 221             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    222 
    223     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

~/envs/espi/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    161 
    162     def _with_events(self, f, event_type, ex, final=noop):
--> 163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
    165         self(f'after_{event_type}');  final()

~/envs/espi/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
    210         for epoch in range(self.n_epoch):
    211             self.epoch=epoch
--> 212             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    213 
    214     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

~/envs/espi/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    161 
    162     def _with_events(self, f, event_type, ex, final=noop):
--> 163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
    165         self(f'after_{event_type}');  final()

~/envs/espi/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
    204 
    205     def _do_epoch(self):
--> 206         self._do_epoch_train()
    207         self._do_epoch_validate()
    208 

~/envs/espi/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_train(self)
    196     def _do_epoch_train(self):
    197         self.dl = self.dls.train
--> 198         self._with_events(self.all_batches, 'train', CancelTrainException)
    199 
    200     def _do_epoch_validate(self, ds_idx=1, dl=None):

~/envs/espi/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    161 
    162     def _with_events(self, f, event_type, ex, final=noop):
--> 163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
    165         self(f'after_{event_type}');  final()

~/envs/espi/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
    167     def all_batches(self):
    168         self.n_iter = len(self.dl)
--> 169         for o in enumerate(self.dl): self.one_batch(*o)
    170 
    171     def _do_one_batch(self):

~/envs/espi/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
    192         b = self._set_device(b)
    193         self._split(b)
--> 194         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    195 
    196     def _do_epoch_train(self):

~/envs/espi/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    161 
    162     def _with_events(self, f, event_type, ex, final=noop):
--> 163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
    165         self(f'after_{event_type}');  final()

~/envs/espi/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
    173         self('after_pred')
    174         if len(self.yb):
--> 175             self.loss_grad = self.loss_func(self.pred, *self.yb)
    176             self.loss = self.loss_grad.clone()
    177         self('after_loss')

~/envs/espi/lib/python3.8/site-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
     33         if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
     34         if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 35         return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
     36 
     37     def to(self, device):

~/envs/espi/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/envs/espi/lib/python3.8/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
   1118 
   1119     def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1120         return F.cross_entropy(input, target, weight=self.weight,
   1121                                ignore_index=self.ignore_index, reduction=self.reduction)
   1122 

~/envs/espi/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2809     """
   2810     if has_torch_function_variadic(input, target):
-> 2811         return handle_torch_function(
   2812             cross_entropy,
   2813             (input, target),

~/envs/espi/lib/python3.8/site-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1250         # Use `public_api` instead of `implementation` so __torch_function__
   1251         # implementations can do equality/identity comparisons.
-> 1252         result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
   1253 
   1254         if result is not NotImplemented:

~/envs/espi/lib/python3.8/site-packages/fastai/torch_core.py in __torch_function__(self, func, types, args, kwargs)
    338         convert=False
    339         if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
--> 340         res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
    341         if convert: res = convert(res)
    342         if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)

~/envs/espi/lib/python3.8/site-packages/torch/_tensor.py in __torch_function__(cls, func, types, args, kwargs)
   1021 
   1022         with _C.DisableTorchFunction():
-> 1023             ret = func(*args, **kwargs)
   1024             return _convert(ret, cls)
   1025 

~/envs/espi/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2822     if size_average is not None or reduce is not None:
   2823         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2825 
   2826 

RuntimeError: weight tensor should be defined either for all 12 classes or no classes but got weight tensor of shape: [1 x 32] at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:44
# Train with the weighted loss for 10 epochs at the stage-1 learning rate, then inspect predictions.
learn.fit_flat_cos(10, slice(lr))
learn.show_results(max_n=4, figsize=(18,8))