import torch, re
tv, cv = torch.__version__, torch.version.cuda
tv = re.sub(r'\+cu.*', '', tv)            # strip any '+cuXXX' build suffix
TORCH_VERSION = 'torch'+tv[0:-1]+'0'      # e.g. '1.9.1' -> 'torch1.9.0'
CUDA_VERSION = 'cu'+cv.replace('.','')    # e.g. '11.1'  -> 'cu111'

print(f"TORCH_VERSION={TORCH_VERSION}; CUDA_VERSION={CUDA_VERSION}")
print(f"CUDA available = {torch.cuda.is_available()}, Device count = {torch.cuda.device_count()}, Current device = {torch.cuda.current_device()}")
print(f"Device name = {torch.cuda.get_device_name()}")
print("hostname:")
!hostname
TORCH_VERSION=torch1.9.0; CUDA_VERSION=cu111
CUDA available = True, Device count = 1, Current device = 0
Device name = GeForce RTX 3080
hostname:
bengio

This article is also a Jupyter Notebook, available to be run from the top down. The code snippets below can be run in any environment.

Below are the versions of fastai, fastcore, wwf, and espiownage currently running at the time of writing this:

  • fastai : 2.5.2
  • fastcore : 1.3.26
  • wwf : 0.0.16
  • espiownage : 0.0.36

Libraries

from fastai.vision.all import *
from espiownage.core import *

Below you will find the exact imports for everything we use today

from fastcore.xtras import Path

from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos

from fastai.data.block import DataBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, FuncSplitter, RandomSplitter, Normalize

from fastai.interpret import SegmentationInterpretation

from fastai.layers import Mish
from fastai.losses import BaseLoss
from fastai.optimizer import ranger

from fastai.torch_core import tensor

from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
from fastai.vision.learner import unet_learner

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from torch import nn
from torchvision.models.resnet import resnet34

import torch
import torch.nn.functional as F

Dataset

#path = untar_data('https://anonymized.machine.com/~drscotthawley/espiownage-cyclegan.tgz')
path = Path('/home/drscotthawley/datasets/espiownage-fake/')
#path = Path('/home/drscotthawley/datasets/espiownage-cleaner/')  # not cleaned but clean-er than before!

Let's look at an image and see how everything lines up

path_im = path/'images'
path_lbl = path/'masks'

First we need our filenames

import glob 
#fnames = get_image_files(path_im)
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
lbl_names = get_image_files(path_lbl)
len(meta_names), len(fnames), len(lbl_names)
(2000, 2000, 2000)

And now let's work with one of them

img_fn = fnames[10]
print(img_fn)
img = PILImage.create(img_fn)
img.show(figsize=(5,5))
/home/drscotthawley/datasets/espiownage-fake/images/steelpan_0000010.png
<AxesSubplot:>

Now let's grab our y's. They live in the masks folder and are denoted by a _P suffix

get_msk = lambda o: path/'masks'/f'{o.stem}_P{o.suffix}'

The stem and suffix grab everything before and after the period respectively.
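For instance, on a hypothetical filename the mapping looks like this:

fn = Path('images/steelpan_0000010.png')  # hypothetical example file
fn.stem, fn.suffix                        # -> ('steelpan_0000010', '.png')
get_msk(fn)                               # -> Path('.../masks/steelpan_0000010_P.png')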

Our masks are of type PILMask, and we will set the transparency (alpha) to 1 since we are not overlaying the mask on anything yet

msk_name = get_msk(img_fn)
print(msk_name)
msk = PILMask.create(msk_name)
msk.show(figsize=(5,5), alpha=1)
/home/drscotthawley/datasets/espiownage-fake/masks/steelpan_0000010_P.png
<AxesSubplot:>

Now if we look at what our mask actually is, we can see it's a giant array of pixels:

print(tensor(msk))
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.uint8)

And just to make sure the mask holds clean class indices rather than anti-aliased intermediate values, let's see what it contains:

set(np.array(msk).flatten())
{0, 1}

Each value represents a class (the CamVid dataset keeps these in a codes.txt file; here we derive them from the mask itself). Let's make a vocabulary with it

#colors = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110]
#colors = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
colors = list(set(np.array(msk).flatten()))
codes = [str(n) for n in range(len(colors))]; codes
['0', '1']

We need a way to split our data. The original CamVid walkthrough builds its own split function with FuncSplitter: the function takes a filename and returns whether that file belongs in the validation set. Here we will simply use RandomSplitter in our DataBlock below, but a sketch of the custom approach follows.
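A minimal sketch, assuming a hypothetical list valid_fnames of validation filenames:

valid_fnames = ['steelpan_0000010.png']         # hypothetical validation list
def is_valid(o): return o.name in valid_fnames  # True puts a file in the validation set
splitter = FuncSplitter(is_valid)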

Transfer Learning between DataSets

Jeremy popularized the idea of progressive resizing:

  • Train on smaller-sized images
  • Then gradually increase the image size
  • Repeat as a transfer-learning loop (see the sketch below)
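As a schematic sketch of that loop (make_dls here is a hypothetical helper that rebuilds our DataLoaders at a given size; below we simply do this by hand in two rounds):

for size in [(192, 256), (384, 512)]:   # half size, then full size
    learn.dls = make_dls(size)          # hypothetical helper: rebuild the DataLoaders at `size`
    learn.fit_flat_cos(10, slice(lr))   # keep training the same weights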

This first round we will train at half the image size

sz = msk.shape; sz
(384, 512)
half = tuple(int(x/2) for x in sz); half
(192, 256)
cyclegan = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
    get_items=get_image_files,
    splitter=RandomSplitter(),
    get_y=get_msk,
    batch_tfms=[*aug_transforms(size=half), Normalize.from_stats(*imagenet_stats)])
dls = cyclegan.dataloaders(path/'images', fnames=fnames, bs=4)
/home/drscotthawley/.local/lib/python3.8/site-packages/torch/_tensor.py:575: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
/home/drscotthawley/.local/lib/python3.8/site-packages/torch/_tensor.py:1023: UserWarning: torch.solve is deprecated in favor of torch.linalg.solveand will be removed in a future PyTorch release.
torch.linalg.solve has its arguments reversed and does not return the LU factorization.
To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
X = torch.solve(B, A).solution
should be replaced with
X = torch.linalg.solve(A, B) (Triggered internally at  /pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp:760.)
  ret = func(*args, **kwargs)

Let's look at a batch and at all the classes

dls.show_batch(max_n=4, vmin=0, vmax=2, figsize=(14,10))

Lastly let's make our vocabulary a part of our DataLoaders, as our loss function needs to deal with the Void label

dls.vocab = codes

Now we need a way to grab a particular code from our network's output numbers. Let's make a dictionary

name2id = {v:k for k,v in enumerate(codes)}  # map each code to its integer index
name2id
{'0': 0, '1': 1}

Awesome! Let's make an accuracy function

void_code = name2id['0'] # name2id['Void']

For segmentation accuracy, we squeeze the target down to a matrix of class indices, take the argmax of the predictions along the class dimension, and then, over every non-void pixel, average how often prediction and target match

def acc_camvid(inp, targ):
    targ = targ.squeeze(1)
    mask = targ != void_code
    if len(targ[mask]) == 0:  mask = (targ == void_code)  # Empty image (all void) 
    return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()
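As a quick sanity check on toy tensors (shapes assumed: predictions are batch x classes x height x width, targets are batch x height x width):

inp  = torch.tensor([[[[2., 0.], [0., 2.]],    # class-0 scores
                      [[0., 2.], [2., 0.]]]])  # class-1 scores; shape (1, 2, 2, 2)
targ = torch.tensor([[[0, 1], [1, 0]]])        # target mask, shape (1, 2, 2)
acc_camvid(inp, targ)                          # -> tensor(1.): both non-void pixels match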

The Dynamic Unet

Let's make a unet_learner that uses some of the new state-of-the-art techniques. Specifically:

  • Self-attention layers: self_attention = True
  • Mish activation function: act_cls = Mish

Along with these we will use Ranger as our optimizer.

opt = ranger
learn = unet_learner(dls, resnet34, metrics=acc_camvid, self_attention=True, act_cls=Mish, opt_func=opt)
/home/drscotthawley/.local/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
learn.lr_find()
SuggestedLRs(valley=0.00019054606673307717)
lr = 1e-4

With our new optimizer we will also want to use a different fit function, fit_flat_cos, which keeps the learning rate flat for most of training and then anneals it with a cosine schedule (a combination that pairs well with Ranger)

learn.fit_flat_cos(12, slice(lr))
epoch train_loss valid_loss acc_camvid time
0 0.147378 0.136856 0.790030 00:30
1 0.105581 0.090802 0.897868 00:30
2 0.092015 0.076377 0.906366 00:30
3 0.081409 0.073532 0.906909 00:30
4 0.074445 0.064937 0.929686 00:30
5 0.071243 0.063878 0.927400 00:30
6 0.060489 0.057416 0.943657 00:31
7 0.063730 0.062996 0.931544 00:31
8 0.060993 0.053980 0.935233 00:30
9 0.057328 0.051203 0.951352 00:31
10 0.052732 0.051565 0.939421 00:31
11 0.053402 0.051502 0.941940 00:31
learn.save('stage-1-real-fake2')   # Zach saves in case Colab dies / gives OOM
learn.load('stage-1-real-fake2');  # he reloads as a way of skipping what came before if he restarts the runtime.
learn.show_results(max_n=4, figsize=(12,6))

Let's unfreeze the model and train with discriminative learning rates, dividing our learning rate by 4 for the deepest layers (a rule of thumb) and by 400 for the earliest

lrs = slice(lr/400, lr/4)
lr, lrs
(0.0001, slice(2.5e-07, 2.5e-05, None))
learn.unfreeze()

And train for a bit more

learn.fit_flat_cos(12, lrs)
epoch train_loss valid_loss acc_camvid time
0 0.051029 0.048564 0.946476 00:33
1 0.047924 0.049289 0.948122 00:33
2 0.047975 0.048101 0.953271 00:33
3 0.049226 0.048499 0.950610 00:33
4 0.048128 0.048449 0.947887 00:33
5 0.048525 0.047438 0.947886 00:33
6 0.047652 0.045650 0.947648 00:33
7 0.045980 0.048731 0.944906 00:33
8 0.044486 0.045627 0.949069 00:33
9 0.043657 0.044666 0.952414 00:33
10 0.044572 0.047426 0.947185 00:33
11 0.043687 0.046461 0.948212 00:33

Now let's save that model away

learn.save('model_1_fake2')
Path('models/model_1_fake2.pth')
learn.load('model_1_fake2')
<fastai.learner.Learner at 0x7f21fc35b3a0>

And look at a few results

learn.show_results(max_n=6, figsize=(10,10))

Inference

Let's take a look at how to do inference with test_dl

dl = learn.dls.test_dl(fnames[0:6])
dl.show_batch()

Let's get predictions for those six images

preds = learn.get_preds(dl=dl)
len(preds)
2
preds[0].shape
torch.Size([6, 2, 192, 256])

Alright, so we have a 6x2x192x256 tensor: six images at our half-size resolution, with one output plane per class

len(codes)
2

What does this mean? We passed in six images, so the first dimension indexes each image in our batch. Let's look at one of them (here the last)

ind = 5
pred_1 = preds[0][ind]
pred_1.shape
torch.Size([2, 192, 256])

Now let's take the argmax of our values

pred_arx = pred_1.argmax(dim=0)

And look at it

plt.imshow(pred_arx)
<matplotlib.image.AxesImage at 0x7f225cc313a0>

What do we do from here? We need to save it away. We can do this in one of two ways: as a numpy array converted to an image, or as a raw tensor (to use later as-is)

pred_arx = pred_arx.numpy()
# rescale the class indices to the full 0-255 range so the mask shows up as an image
rescaled = (255.0 / pred_arx.max() * (pred_arx - pred_arx.min())).astype(np.uint8)
im = Image.fromarray(rescaled)
im
im.save('test_fake2.png')

Let's do the same for each of our files in a loop (a function version follows below)

for i, pred in enumerate(preds[0]):
  pred_arg = pred.argmax(dim=0).numpy()
  rescaled = (255.0 / pred_arg.max() * (pred_arg - pred_arg.min())).astype(np.uint8)
  im = Image.fromarray(rescaled)
  im.save(f'Image_{i}_fake2.png')
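And here is a minimal sketch of that same logic wrapped up as a reusable function (the name and arguments are our own):

def save_pred_masks(preds, prefix='Image', suffix='_fake2'):
    "Save each predicted mask in `preds` as a rescaled 8-bit PNG."
    for i, pred in enumerate(preds):
        pred_arg = pred.argmax(dim=0).numpy()
        # rescale class indices to 0-255 so the mask shows up as an image
        rescaled = (255.0 / pred_arg.max() * (pred_arg - pred_arg.min())).astype(np.uint8)
        Image.fromarray(rescaled).save(f'{prefix}_{i}{suffix}.png')

save_pred_masks(preds[0])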

Now let's save away the raw:

torch.save(preds[0][ind], 'Image_1_fake2.pt')
pred_1 = torch.load('Image_1_fake2.pt')
plt.imshow(pred_1.argmax(dim=0))
<matplotlib.image.AxesImage at 0x7f21fc2d0eb0>

Full Size

Now let's go full-size. Restart your instance to free the memory again, then re-run the setup

from fastai.vision.all import *
from espiownage.core import *
import glob 
path = Path('/home/drscotthawley/datasets/espiownage-fake/')
path_im = path/'images'
path_lbl = path/'masks'
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
lbl_names = get_image_files(path_lbl)

get_msk = lambda o: path/'masks'/f'{o.stem}_P{o.suffix}'
colors = [0, 1]
codes = [str(n) for n in colors]; codes
sz = (384, 512)
name2id = {v:int(v) for k,v in enumerate(codes)}
void_code = name2id['0']

And re-make our dataloaders, but this time at the full image size

seg_db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_image_files,
                   splitter=RandomSplitter(),
                   get_y=get_msk,
                   batch_tfms=[*aug_transforms(size=sz), Normalize.from_stats(*imagenet_stats)])

We'll also want to lower our batch size so we don't run out of memory

dls = seg_db.dataloaders(path/"images", fnames=fnames, bs=2)

Let's assign our vocab, make our learner, and load our weights

opt = ranger
def acc_camvid2(inp, targ):
    targ = targ.squeeze(1) 
    mask = targ != void_code  
    if len(targ[mask]) == 0:  mask = (targ == void_code)  # Empty image 
    return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()

dls.vocab = codes
learn = unet_learner(dls, resnet34, metrics=acc_camvid2, self_attention=True, act_cls=Mish, opt_func=opt)
learn.load('model_1_real');

And now let's find our learning rate and train!

learn.lr_find()
SuggestedLRs(valley=7.585775892948732e-05)
lr = 7e-5
learn.fit_flat_cos(10, slice(lr))
epoch train_loss valid_loss acc_camvid2 time
0 0.088198 0.078247 0.876046 01:40
1 0.070138 0.064181 0.893534 01:40
2 0.062380 0.059236 0.929729 01:40
3 0.049780 0.052593 0.921827 01:41
4 0.044831 0.051482 0.929403 01:41
5 0.044021 0.048819 0.931578 01:40
6 0.043577 0.049762 0.930833 01:40
7 0.045576 0.043703 0.940147 01:40
8 0.037397 0.040118 0.943982 01:40
9 0.036702 0.040967 0.944299 01:40
learn.save('seg_full_1_fake2')
Path('models/seg_full_1_fake2.pth')
learn.unfreeze()
lrs = slice(1e-6,lr/10); lrs
slice(1e-06, 6.999999999999999e-06, None)
learn.fit_flat_cos(10, lrs)
epoch train_loss valid_loss acc_camvid2 time
0 0.037385 0.040444 0.943076 01:47
1 0.035534 0.042469 0.939450 01:47
2 0.033553 0.039952 0.944300 01:47
3 0.034228 0.040184 0.942997 01:47
4 0.034200 0.040859 0.941651 01:47
5 0.034134 0.040110 0.942516 01:47
6 0.032345 0.037883 0.947169 01:47
7 0.030923 0.038920 0.944714 01:47
8 0.032584 0.038794 0.945021 01:47
9 0.031989 0.039260 0.945023 01:47
learn.save('seg_full_2_fake2')
Path('models/seg_full_2_fake2.pth')
learn.show_results(max_n=4, figsize=(18,8))
interp = SegmentationInterpretation.from_learner(learn)
nplot = 10
for x, i in enumerate(interp.top_losses(nplot).indices, start=1):
    print(f"[{x}] {dls.valid_ds.items[i]}")
interp.plot_top_losses(k=nplot)
[1] /home/drscotthawley/datasets/espiownage-fake/images/steelpan_0000432.png
[2] /home/drscotthawley/datasets/espiownage-fake/images/steelpan_0001433.png
[3] /home/drscotthawley/datasets/espiownage-fake/images/steelpan_0000172.png
[4] /home/drscotthawley/datasets/espiownage-fake/images/steelpan_0001201.png
[5] /home/drscotthawley/datasets/espiownage-fake/images/steelpan_0001646.png
[6] /home/drscotthawley/datasets/espiownage-fake/images/steelpan_0001973.png
[7] /home/drscotthawley/datasets/espiownage-fake/images/steelpan_0000190.png
[8] /home/drscotthawley/datasets/espiownage-fake/images/steelpan_0000176.png
[9] /home/drscotthawley/datasets/espiownage-fake/images/steelpan_0000929.png
[10] /home/drscotthawley/datasets/espiownage-fake/images/steelpan_0000623.png
preds, targs, losses = learn.get_preds(with_loss=True) # validation set only
print(preds.shape, targs.shape)
len(preds)
torch.Size([400, 2, 384, 512]) torch.Size([400, 384, 512])
400
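As a quick sanity check (our own addition: this is plain pixel accuracy over every pixel, so unlike acc_camvid2 it does not mask out the void class):

(preds.argmax(dim=1) == targs).float().mean()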