Here we take Zach Mueller's CamVid segmentation tutorial and try to segment our real data as object vs. background (an "all one" class rather than multiple classes).

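import torch, re
tv, cv = torch.__version__, torch.version.cuda
tv = re.sub(r'\+cu.*', '', tv)             # strip the local build tag, e.g. '1.9.0+cu111' -> '1.9.0'
TORCH_VERSION = 'torch' + tv[0:-1] + '0'   # set the patch digit to 0, e.g. 'torch1.9.0'
CUDA_VERSION = 'cu' + cv.replace('.', '')  # e.g. '11.1' -> 'cu111'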

print(f"CUDA available = {torch.cuda.is_available()}, Device count = {torch.cuda.device_count()}, Current device = {torch.cuda.current_device()}")
print(f"Device name = {torch.cuda.get_device_name()}")
CUDA available = True, Device count = 1, Current device = 0
Device name = GeForce RTX 3080

This article is also a Jupyter Notebook available to be run from the top down. There will be code snippets that you can then run in any environment.

Below are the versions of fastai, fastcore, wwf, and espiownage currently running at the time of writing this:

  • fastai : 2.5.2
  • fastcore : 1.3.26
  • wwf : 0.0.16
  • espiownage : 0.0.36


from fastai.vision.all import *
from espiownage.core import *

Below you will find the exact imports for everything we use today

from fastcore.xtras import Path

from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos

from fastai.data.block import DataBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, FuncSplitter, Normalize

from fastai.layers import Mish
from fastai.losses import BaseLoss
from fastai.optimizer import ranger

from fastai.torch_core import tensor

from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
from fastai.vision.learner import unet_learner

from PIL import Image
import numpy as np

from torch import nn
from torchvision.models.resnet import resnet34

import torch
import torch.nn.functional as F


#path = untar_data('')
path = Path('/home/drscotthawley/datasets/espiownage-cyclegan/')
#path = Path('/home/drscotthawley/datasets/espiownage-cleaner/')  # not cleaned but clean-er than before!

Let's look at an image and see how everything lines up.

path_im = path/'images'
path_lbl = path/'masks'

First we need our filenames

import glob 
#fnames = get_image_files(path_im)
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
lbl_names = get_image_files(path_lbl)
len(meta_names), len(fnames), len(lbl_names)
(1200, 1200, 1200)

And now let's work with one of them

img_fn = fnames[10]
img = PILImage.create(img_fn)
img.show(figsize=(5,5))

Now let's grab our y's. They live in the masks folder, and each mask's filename is the image's stem with _P appended.

get_msk = lambda o: path/'masks'/f'{o.stem}_P{o.suffix}'

The stem is the filename without its extension. The suffix is the extension, including the period.
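Here is a quick check of what get_msk builds, using only pathlib. The root directory is a hypothetical stand-in for `path`. The filename simply follows the steelpan_XXXXXXX.png pattern that appears later in this notebook.

```python
from pathlib import Path

# Hypothetical dataset root standing in for the notebook's `path`
root = Path('/data/espiownage-cyclegan')

# Same naming rule as get_msk: masks/<stem>_P<suffix>
get_msk = lambda o: root/'masks'/f'{o.stem}_P{o.suffix}'

img_fn = root/'images'/'steelpan_0000011.png'
print(img_fn.stem)            # steelpan_0000011
print(img_fn.suffix)          # .png  (the period is part of the suffix)
print(get_msk(img_fn).name)   # steelpan_0000011_P.png
```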

Our masks are of type PILMask. We set alpha (the opacity) to 1 because we are not overlaying the mask on anything yet.

msk_name = get_msk(img_fn)
msk = PILMask.create(msk_name)
msk.show(figsize=(5,5), alpha=1)

Now if we look at what our mask actually is, we can see it's a giant array of pixels:

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.uint8)

Let's also make sure it's a simple label image with no anti-aliasing by checking which values it contains:

{0, 1}

Each value represents a class. CamVid reads its class names from a codes.txt file. Here, instead, we build the vocabulary from the values present in the mask:

#colors = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110]
#colors = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
colors = list(set(np.array(msk).flatten()))
codes = [str(n) for n in range(len(colors))]; codes
['0', '1']
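The same check can also be done with np.unique, which returns the distinct values already sorted (a Python set makes no ordering promise). The 4x4 mask below is a made-up toy, not real data.

```python
import numpy as np

# Toy 4x4 mask standing in for np.array(msk): 0 = background, 1 = object
mask = np.array([[0, 0, 0, 0],
                 [0, 1, 1, 0],
                 [0, 1, 1, 0],
                 [0, 0, 0, 0]], dtype=np.uint8)

colors = np.unique(mask)                       # sorted distinct pixel values
codes = [str(n) for n in range(len(colors))]   # one code per class id
print(colors.tolist(), codes)                  # [0, 1] ['0', '1']

# An anti-aliased mask would show extra in-between values (e.g. 128) here
assert set(colors.tolist()) <= {0, 1}, "mask contains values other than 0/1"
```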

Zach's tutorial uses a split function at this point. It checks each filename against a list of validation filenames.

We don't pass a splitter to our DataBlock below, so fastai falls back to its default random split and holds out 20% of the items for validation. That is why the validation set later has 240 of our 1200 images.
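If you do have a list of validation filenames, such a split function boils down to a per-file predicate. Below is a plain-Python sketch of that idea. The file list and `valid_stems` set are hypothetical. `func_split` only mimics what fastai's FuncSplitter does with a predicate; in fastai you would pass `splitter=FuncSplitter(is_valid)` to the DataBlock.

```python
from pathlib import Path

# Hypothetical items and validation list, for illustration only
fnames = [Path(f'images/steelpan_{i:07d}.png') for i in range(6)]
valid_stems = {'steelpan_0000001', 'steelpan_0000004'}

def is_valid(fn):
    "True if this file belongs to the validation set"
    return fn.stem in valid_stems

def func_split(items, func):
    "Split item indices into (train, valid) with a predicate, FuncSplitter-style"
    valid_idx = [i for i, o in enumerate(items) if func(o)]
    valid_set = set(valid_idx)
    train_idx = [i for i in range(len(items)) if i not in valid_set]
    return train_idx, valid_idx

train_idx, valid_idx = func_split(fnames, is_valid)
print(train_idx, valid_idx)   # [0, 2, 3, 5] [1, 4]
```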

Transfer Learning between DataSets

Jeremy Howard popularized the idea of progressive resizing:

  • Train on smaller images first
  • Then move to larger and larger images
  • Carry the weights over at each step, like a transfer-learning loop
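The size schedule for such a loop can be derived from the full image size. This is a sketch only. The quarter-size stage is an illustrative extra, since this notebook trains at just half and full size. `make_dls` in the comment is a hypothetical helper, not part of fastai.

```python
# Full image size (height, width) of this dataset's masks
full = (384, 512)

# Illustrative schedule: quarter, half, then full resolution
factors = [4, 2, 1]
sizes = [tuple(s // f for s in full) for f in factors]
print(sizes)   # [(96, 128), (192, 256), (384, 512)]

# Each stage would rebuild the DataLoaders at the new size and keep training
# the same weights, e.g. (pseudocode with a hypothetical make_dls helper):
#   for size in sizes:
#       learn.dls = make_dls(size)
#       learn.fit_flat_cos(...)
```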

This first round we will train at half the image size

sz = msk.shape; sz
(384, 512)
half = tuple(int(x/2) for x in sz); half
(192, 256)
cyclegan = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
    get_y=get_msk,   # image path -> matching mask path
    batch_tfms=[*aug_transforms(size=half), Normalize.from_stats(*imagenet_stats)])
dls = cyclegan.dataloaders(fnames, bs=4)   # fnames are the items, so no get_items is needed

Let's look at a batch, and look at all the classes

dls.show_batch(max_n=4, vmin=0, vmax=2, figsize=(14,10))

Lastly, let's make our vocabulary part of our DataLoaders. In CamVid this matters because of the Void label. Here, the background class ('0') plays that role in our accuracy metric.

dls.vocab = codes

Now we need a way to look up a class's integer id from its name. Let's put everything into a dictionary:

name2id = {v:int(v) for k,v in enumerate(codes)}
{'0': 0, '1': 1}

Awesome! Let's make an accuracy function

void_code = name2id['0'] # name2id['Void']

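For segmentation, we first squeeze the target's singleton channel dimension so it becomes a plain matrix of class ids. Then we take the argmax of the model's output over the class dimension and compare it with the target pixel by pixel, skipping void pixels. Finally we take the average. If a batch is entirely void, we score all of its pixels instead.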

def acc_camvid(inp, targ):
    targ = targ.squeeze(1)
    mask = targ != void_code
    if len(targ[mask]) == 0:  mask = (targ == void_code)  # Empty image (all void) 
    return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()
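To see the metric's logic on numbers small enough to check by hand, here is a NumPy re-implementation of the same idea. It is a sketch for intuition, not the fastai metric object itself. Because background is the "void" code here, it effectively measures how much of the object the model recovers.

```python
import numpy as np

void_code = 0  # background plays the "void" role, as in the notebook

def masked_acc(logits, targ):
    "Pixel accuracy over non-void pixels. logits: (N, C, H, W), targ: (N, 1, H, W)"
    targ = targ.squeeze(1)          # (N, H, W) matrix of class ids
    pred = logits.argmax(axis=1)    # (N, H, W) predicted class ids
    mask = targ != void_code        # only score object pixels...
    if not mask.any():              # ...unless the whole batch is void
        mask = targ == void_code
    return (pred[mask] == targ[mask]).mean()

# One 2x2 image: three object pixels, one background pixel
targ = np.array([[[[0, 1],
                   [1, 1]]]])
logits = np.zeros((1, 2, 2, 2))
logits[0, 1] = [[-1, 1],
                [-1, 1]]            # the class-1 score wins in the right column only

print(masked_acc(logits, targ))                 # about 0.667: 2 of 3 object pixels right
print(masked_acc(logits, np.zeros_like(targ)))  # 0.5: all-void fallback scores all 4 pixels
```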

The Dynamic Unet

U-Net gives us pixel-wise predictions by first shrinking the image down and then blowing it back up to high resolution. The first part is called the "encoder" and the second the "decoder".
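To make "shrink, then blow back up" concrete, here is a small sketch that only tracks spatial sizes. It assumes the encoder halves the resolution five times, since ResNet-34's overall stride is 32. That assumption is consistent with the 6 x 8 feature maps at the bottom of the model summary for a 192 x 256 input. The decoder then walks back up through the same sizes, joining encoder features of matching size through skip connections.

```python
def unet_ladder(h, w, n_down=5):
    "Spatial sizes on the way down (encoder) and back up (decoder)"
    down = [(h, w)]
    for _ in range(n_down):
        assert h % 2 == 0 and w % 2 == 0, "size must halve cleanly"
        h, w = h // 2, w // 2
        down.append((h, w))
    up = list(reversed(down))   # the decoder upsamples back through the same sizes
    return down, up

down, up = unet_ladder(192, 256)
print(down)   # [(192, 256), (96, 128), (48, 64), (24, 32), (12, 16), (6, 8)]
```

This is also why image sides that are multiples of 32 (like 192 x 256 and 384 x 512) are convenient for this encoder.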

In the architecture figure of the U-Net paper (not reproduced here), the authors note that the arrows denote the different operations.

fastai has a special unet_learner. Something new is that we can pass in model configuration options to customize it:

  • blur / blur_final: blur to avoid checkerboard artifacts
  • self_attention: add a self-attention layer
  • y_range: pass the last activations through a sigmoid to rescale them into a range
  • last_cross: a cross-connection with the direct model input
  • bottle: whether to use a bottleneck on that cross-connection
  • act_cls: the activation function
  • norm_type: the normalization type
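For the y_range option, here is a minimal sketch of the sigmoid rescaling. It assumes the usual formula low + (high - low) * sigmoid(x), which I believe matches fastai's sigmoid_range. This notebook does not set y_range, because its outputs are per-class scores fed to cross-entropy rather than values in a range.

```python
import math

def sigmoid_range(x, low, high):
    "Squash x into (low, high) with a sigmoid, as a y_range head does"
    return low + (high - low) / (1 + math.exp(-x))

print(sigmoid_range(0.0, 0, 4))    # 2.0, the midpoint of the range
print(sigmoid_range(20.0, -1, 1))  # just under 1, the top of the range
```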

Let's make a unet_learner that uses some state-of-the-art techniques. Specifically:

  • Self-attention layers: self_attention = True
  • Mish activation function: act_cls = Mish

Along with this, we will use Ranger as our optimizer.
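Mish is defined as x * tanh(softplus(x)), where softplus(x) = log(1 + e^x). Ranger, as fastai defines it, wraps RAdam in Lookahead. Below is a plain-Python sketch of Mish on scalars, for intuition only; fastai's Mish layer operates on tensors.

```python
import math

def softplus(x):
    # log(1 + e^x), rearranged for large positive x to avoid overflow
    return x + math.log1p(math.exp(-x)) if x > 0 else math.log1p(math.exp(x))

def mish(x):
    "Mish activation: x * tanh(softplus(x))"
    return x * math.tanh(softplus(x))

for v in (-2.0, 0.0, 2.0):
    print(v, round(mish(v), 4))
```

Unlike ReLU, Mish is smooth and lets small negative values through; its minimum is about -0.31.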

opt = ranger
learn = unet_learner(dls, resnet34, metrics=acc_camvid, self_attention=True, act_cls=Mish, opt_func=opt)
DynamicUnet (Input shape: 4)
Layer (type)         Output Shape         Param #    Trainable 
                     4 x 64 x 96 x 128   
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
                     4 x 128 x 24 x 32   
Conv2d                                    73728      False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    8192       False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
                     4 x 256 x 12 x 16   
Conv2d                                    294912     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    32768      False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
                     4 x 512 x 6 x 8     
Conv2d                                    1179648    False     
BatchNorm2d                               1024       True      
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Conv2d                                    131072     False     
BatchNorm2d                               1024       True      
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
BatchNorm2d                               1024       True      
                     4 x 1024 x 6 x 8    
Conv2d                                    4719616    True      
                     4 x 512 x 6 x 8     
Conv2d                                    4719104    True      
                     4 x 1024 x 6 x 8    
Conv2d                                    525312     True      
BatchNorm2d                               512        True      
Conv2d                                    2359808    True      
Conv2d                                    2359808    True      
                     4 x 1024 x 12 x 16  
Conv2d                                    525312     True      
BatchNorm2d                               256        True      
Conv2d                                    1327488    True      
Conv2d                                    1327488    True      
                     4 x 48 x 768        
Conv1d                                    18432      True      
Conv1d                                    18432      True      
Conv1d                                    147456     True      
                     4 x 768 x 24 x 32   
Conv2d                                    295680     True      
BatchNorm2d                               128        True      
Conv2d                                    590080     True      
Conv2d                                    590080     True      
                     4 x 512 x 48 x 64   
Conv2d                                    131584     True      
BatchNorm2d                               128        True      
                     4 x 96 x 96 x 128   
Conv2d                                    165984     True      
Conv2d                                    83040      True      
                     4 x 384 x 96 x 128  
Conv2d                                    37248      True      
Conv2d                                    88308      True      
Conv2d                                    88308      True      
                     4 x 2 x 192 x 256   
Conv2d                                    200        True      

Total params: 41,405,488
Total trainable params: 20,137,840
Total non-trainable params: 21,267,648

Optimizer used: <function ranger at 0x7f3a0c7d0820>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

  - TrainEvalCallback
  - Recorder
  - ProgressCallback

The learn.summary() output above shows this blow-up trend and that our model came in frozen. Next we need a learning rate. Rather than running learn.lr_find() here, we simply use 1e-4:

lr = 1e-4

With our new optimizer, we will also want to use a different fit function, called fit_flat_cos
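My understanding is that fit_flat_cos keeps the learning rate flat for the first 75% of the iterations (pct_start=0.75). It then cosine-anneals the rate down to lr/div_final, where div_final defaults to 1e5. The function below is a plain-Python sketch of that shape under those assumed defaults, not fastai's code.

```python
import math

def flat_cos_lr(pos, lr, pct_start=0.75, div_final=1e5):
    "LR at training progress pos in [0, 1]: flat, then cosine-annealed to lr/div_final"
    if pos < pct_start:
        return lr
    p = (pos - pct_start) / (1 - pct_start)   # progress within the annealing phase
    end = lr / div_final
    return lr + (1 + math.cos(math.pi * (1 - p))) * (end - lr) / 2

lr = 1e-4
for pos in (0.0, 0.5, 0.75, 0.875, 1.0):
    print(pos, flat_cos_lr(pos, lr))
```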

learn.fit_flat_cos(12, slice(lr))
epoch train_loss valid_loss acc_camvid time
0 0.257652 0.199433 0.723017 00:21
1 0.195187 0.168288 0.704716 00:19
2 0.180804 0.149645 0.781627 00:19
3 0.166444 0.140235 0.826935 00:19
4 0.160800 0.139796 0.814588 00:19
5 0.150028 0.136093 0.785896 00:19
6 0.141635 0.129462 0.839787 00:19
7 0.142565 0.127221 0.847131 00:19
8 0.127422 0.123125 0.857813 00:19
9 0.126634 0.122602 0.792614 00:19
10 0.128560 0.117257 0.827089 00:19
11 0.128711 0.116848 0.852421 00:19
learn.save('stage-1-real-cg')   # Zach saves in case Colab dies / gives OOM
learn.load('stage-1-real-cg');  # he reloads as a way of skipping what came before if he restarts the runtime.
learn.show_results(max_n=4, figsize=(12,6))

Let's unfreeze the model and lower our learning rate, dividing it by 4 for the last layers (a rule of thumb) and by 400 for the earliest ones.

lrs = slice(lr/400, lr/4)
lr, lrs
(0.0001, slice(2.5e-07, 2.5e-05, None))
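How does a slice turn into per-layer learning rates? As I understand fastai's convention, slice(start, stop) is spread over the parameter groups evenly on a log scale (its even_mults helper). A bare slice(stop) gives the last group stop and the others stop/10. The sketch below re-implements the log spacing in plain Python. It assumes 3 parameter groups, consistent with the summary's "frozen up to parameter group #2".

```python
def even_mults(start, stop, n):
    "n values from start to stop, evenly spaced on a log scale"
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step**i for i in range(n)]

lr = 1e-4
group_lrs = even_mults(lr / 400, lr / 4, 3)   # assumed: 3 parameter groups
print(group_lrs)                              # roughly 2.5e-07, 2.5e-06, 2.5e-05
```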

And train for a bit more

learn.unfreeze()   # make every parameter group trainable so each gets its own LR from the slice
learn.fit_flat_cos(12, lrs)
epoch train_loss valid_loss acc_camvid time
0 0.106870 0.106321 0.866683 00:20
1 0.109138 0.105891 0.856629 00:20
2 0.109416 0.107211 0.862893 00:20
3 0.108425 0.104959 0.871750 00:20
4 0.104698 0.105199 0.860683 00:20
5 0.104129 0.105591 0.873060 00:20
6 0.108058 0.104724 0.866128 00:20
7 0.100461 0.103512 0.874267 00:20
8 0.098341 0.105129 0.862853 00:20
9 0.099847 0.101906 0.866114 00:20
10 0.095101 0.101210 0.867531 00:20
11 0.096140 0.102163 0.870525 00:20

Now let's save that model away:

learn.save('model_1_cg')

And look at a few results

learn.show_results(max_n=6, figsize=(10,10))


Let's take a look at how to do inference with test_dl

dl = learn.dls.test_dl(fnames[0:6])

Let's run predictions on those first six pictures

preds = learn.get_preds(dl=dl)
torch.Size([6, 2, 192, 256])

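Alright, so we have a 6x2x192x256 tensor.

What does this mean? We passed in six images, so the first dimension indexes the images in our batch. The second dimension holds one score map per class (we have 2 classes). The last two dimensions are the height and width at our half size. Let's look at the last image (index 5):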
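Here is the same indexing and argmax with NumPy. Random numbers stand in for the real predictions; only the shapes matter.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for preds[0]: 6 images x 2 classes x 192 x 256 pixels of class scores
probs = rng.random((6, 2, 192, 256))

ind = 5
pred_1 = probs[ind]                # (2, 192, 256): one score map per class
pred_arx = pred_1.argmax(axis=0)   # (192, 256): winning class id at each pixel
print(pred_1.shape, pred_arx.shape, np.unique(pred_arx))
```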

ind = 5
pred_1 = preds[0][ind]
torch.Size([2, 192, 256])

Now let's take the argmax of our values

pred_arx = pred_1.argmax(dim=0)

And look at it


What do we do from here? We need to save it away. We can do this in one of two ways: as an image (via a NumPy array), or as a raw tensor to use later as-is.

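pred_arx = pred_arx.numpy()
# stretch class ids to 0..255; guard against a constant (e.g. all-background) prediction
span = pred_arx.max() - pred_arx.min()
rescaled = (255.0 * (pred_arx - pred_arx.min()) / (span if span > 0 else 1)).astype(np.uint8)
im = Image.fromarray(rescaled)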

Let's do the same for all of our predictions:

for i, pred in enumerate(preds[0]):
    pred_arg = pred.argmax(dim=0).numpy()
    span = pred_arg.max() - pred_arg.min()   # 0 for an all-background prediction
    rescaled = (255.0 * (pred_arg - pred_arg.min()) / (span if span > 0 else 1)).astype(np.uint8)
    im = Image.fromarray(rescaled)
    im.save(f'Image_{i}_real.png')

Now let's save away the raw tensor:

torch.save(preds[0][ind], '')
pred_1 = torch.load('')

Full Size

Now let's go full size. Restart your kernel to free up GPU memory, then re-run the setup below:

from fastai.vision.all import *
from espiownage.core import *
import glob 
path = Path('/home/drscotthawley/datasets/espiownage-cyclegan/')
path_im = path/'images'
path_lbl = path/'masks'
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
lbl_names = get_image_files(path_lbl)

get_msk = lambda o: path/'masks'/f'{o.stem}_P{o.suffix}'
colors = [0, 1]
codes = [str(n) for n in colors]; codes
sz = (384, 512)
name2id = {v:int(v) for k,v in enumerate(codes)}
void_code = name2id['0']

And re-make our dataloaders. But this time we want our size to be the full size

seg_db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_y=get_msk,   # image path -> matching mask path
                   batch_tfms=[*aug_transforms(size=sz), Normalize.from_stats(*imagenet_stats)])

We'll also want to lower our batch size so we don't run out of memory

dls = seg_db.dataloaders(fnames, bs=2)   # fnames are the items, so no get_items is needed
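Here is a rough back-of-the-envelope check on that batch size. It assumes activation memory scales with the number of pixels in a batch; weights and optimizer state don't change with image size, so this is only an approximation.

```python
# Pixels per image at each resolution used in this notebook
half_px = 192 * 256
full_px = 384 * 512

ratio_per_image = full_px / half_px               # 4.0
ratio_per_batch = (2 * full_px) / (4 * half_px)   # 2.0: bs=2 at full size vs bs=4 at half size
print(ratio_per_image, ratio_per_batch)
```

So even after halving the batch size, each full-size batch carries about twice as many pixels as a half-size batch did.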

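Let's assign our vocab, define our accuracy metrics, and make our learner. Note that the code below does not load the half-size weights. To carry them over, you could call learn.load('model_1_cg') after creating the learner.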

opt = ranger
def acc_camvid2(inp, targ, warn=False):
    targ = targ.squeeze(1) 
    mask = targ != void_code  # where it's nonzero
    if len(targ[mask]) == 0:  # Empty image (all void)
        mask = (targ == void_code)  
        if warn:
            acc_empty = (inp.argmax(dim=1)[mask]==targ[mask]).float().mean() # score based on what's correct overall (~100%?)
            print("Empty image, acc_empty = ",acc_empty.cpu().numpy())
    return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean() 

def acc_camvid3(inp, targ):
    mask = inp.argmax(dim=1) == targ.squeeze(1) # could give inflated scores for images dominated by void
    return mask.float().mean()
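Why might acc_camvid3 give inflated scores? Consider a toy NumPy example (made-up arrays, already argmaxed) with a small object on a large background and a "model" that predicts background everywhere. Scoring every pixel looks great; scoring only non-void pixels, as acc_camvid2 does, reveals the miss.

```python
import numpy as np

void_code = 0

def acc_nonvoid(pred, targ):
    "Like acc_camvid2: score only non-void (object) pixels"
    mask = targ != void_code
    if not mask.any():
        mask = targ == void_code
    return (pred[mask] == targ[mask]).mean()

def acc_all(pred, targ):
    "Like acc_camvid3: score every pixel"
    return (pred == targ).mean()

# Toy 10x10 target with a 2x2 object; the prediction is all background
targ = np.zeros((10, 10), dtype=int)
targ[4:6, 4:6] = 1
pred = np.zeros_like(targ)

print(acc_all(pred, targ))      # 0.96: looks great
print(acc_nonvoid(pred, targ))  # 0.0: the object was missed entirely
```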

dls.vocab = codes
learn = unet_learner(dls, resnet34, metrics=acc_camvid2, self_attention=True, act_cls=Mish, opt_func=opt)

And now let's set our learning rate and train!

lr = 1e-4
learn.fit_flat_cos(10, slice(lr))
epoch train_loss valid_loss acc_camvid2 time
0 0.171015 0.154582 0.749763 01:01
1 0.155260 0.142730 0.798156 01:02
2 0.129668 0.134481 0.837203 01:02
3 0.117391 0.125572 0.837247 01:02
4 0.114318 0.125483 0.824614 01:02
5 0.118929 0.121923 0.834170 01:02
6 0.105823 0.114306 0.844533 01:02
7 0.105705 0.115865 0.812557 01:02
8 0.095831 0.116311 0.843808 01:02
9 0.088960 0.113454 0.843645 01:02
learn.save('seg_full_1_cg')
lrs = slice(1e-6,lr/10); lrs
slice(1e-06, 1e-05, None)
learn.fit_flat_cos(10, lrs)
epoch train_loss valid_loss acc_camvid2 time
0 0.081679 0.110383 0.851277 01:06
1 0.087923 0.107981 0.847036 01:06
2 0.083881 0.112153 0.842376 01:06
3 0.076471 0.110485 0.838301 01:06
4 0.079998 0.104261 0.854367 01:06
5 0.091270 0.104432 0.861100 01:06
6 0.083442 0.100981 0.868755 01:06
7 0.076202 0.104505 0.855163 01:06
8 0.082528 0.102581 0.850889 01:06
9 0.073828 0.107803 0.847006 01:06
learn.save('seg_full_2_cg')
learn.show_results(max_n=4, figsize=(18,8))
interp = SegmentationInterpretation.from_learner(learn)
nplot, x = 10, 1
for i in interp.top_losses(nplot).indices:
    print(f"[{x}] {dls.valid_ds.items[i]}")
    x += 1
[1] /home/drscotthawley/datasets/espiownage-cyclegan/images/steelpan_0000461.png
[2] /home/drscotthawley/datasets/espiownage-cyclegan/images/steelpan_0000764.png
[3] /home/drscotthawley/datasets/espiownage-cyclegan/images/steelpan_0000297.png
[4] /home/drscotthawley/datasets/espiownage-cyclegan/images/steelpan_0040147.png
[5] /home/drscotthawley/datasets/espiownage-cyclegan/images/steelpan_0000630.png
[6] /home/drscotthawley/datasets/espiownage-cyclegan/images/steelpan_0000234.png
[7] /home/drscotthawley/datasets/espiownage-cyclegan/images/steelpan_0000011.png
[8] /home/drscotthawley/datasets/espiownage-cyclegan/images/steelpan_0000507.png
[9] /home/drscotthawley/datasets/espiownage-cyclegan/images/steelpan_0000849.png
[10] /home/drscotthawley/datasets/espiownage-cyclegan/images/steelpan_0000802.png
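top_losses ranks the validation items by their loss. The same ranking can be sketched with np.argsort on a vector of per-image losses; the losses and names below are made up for illustration.

```python
import numpy as np

# Hypothetical per-image validation losses and file names
losses = np.array([0.05, 0.31, 0.12, 0.44, 0.08])
names = ['a.png', 'b.png', 'c.png', 'd.png', 'e.png']

k = 3
top = np.argsort(-losses)[:k]   # indices of the k largest losses, largest first
for rank, i in enumerate(top, start=1):
    print(f"[{rank}] {names[i]} loss={losses[i]:.2f}")
```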
preds, targs, losses = learn.get_preds(with_loss=True) # validation set only
print(preds.shape, targs.shape)
torch.Size([240, 2, 384, 512]) torch.Size([240, 384, 512])
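def save_tmask(tmask, fname, argmax=True):
    "save tensor mask as an 8-bit image file"
    tmask_new = tmask.argmax(dim=0).cpu().numpy() if argmax else tmask.cpu().numpy()
    span = tmask_new.max() - tmask_new.min()   # 0 for a constant (e.g. all-background) mask
    rescaled = (255.0 * (tmask_new - tmask_new.min()) / (span if span > 0 else 1)).astype(np.uint8)
    im = Image.fromarray(rescaled)
    im.save(fname)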
seg_img_dir = 'seg_images_cg'
!rm -rf {seg_img_dir}; mkdir {seg_img_dir}
results = []
for i in range(len(preds)):
    filestem = dls.valid.items[i].stem
    line_list = [filestem, float(losses[i]), i]   # plain float so pandas can sort on it
    save_tmask(preds[i], seg_img_dir+'/'+filestem+'_pred.png')
    results.append(line_list)

# store as pandas dataframe
res_df = pd.DataFrame(results, columns=['filename', 'loss','i'])
res_df = res_df.sort_values('loss', ascending=False)
res_df.to_csv('segmentation_top_losses_cg.csv', index=False)