Acknowledgement: I took Zach Mueller's Image Segmentation tutorial notebook (based on the main FastAI lesson notebook) and modified it to do regression (per Zach's suggestions) and to work with my own data.

 
!pip install -Uqq espiownage

Libraries

from fastai.vision.all import *
from espiownage.core import *

Below you will find the exact imports for everything we use today

from fastcore.xtras import Path

from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos

from fastai.data.block import DataBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, FuncSplitter, Normalize, RandomSplitter

from fastai.layers import Mish   # MishJIT gives me trouble :-( 
from fastai.losses import BaseLoss, MSELossFlat, CrossEntropyLossFlat, BCEWithLogitsLossFlat
from fastai.optimizer import ranger

from fastai.torch_core import tensor, flatten_check
from fastai.metrics import mae

from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
from fastai.vision.learner import unet_learner

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random

from torch import nn
from torchvision.models.resnet import resnet34

import torch
import torch.nn.functional as F

import glob
from pathlib import Path
import re

tv, cv = torch.__version__, torch.version.cuda
tv = re.sub(r'\+cu.*', '', tv)            # strip any '+cuXXX' build suffix
TORCH_VERSION = 'torch' + tv[0:-1] + '0'  # pin the patch version to 0, e.g. '1.9.1' -> 'torch1.9.0'
CUDA_VERSION = 'cu' + cv.replace('.','')

print(f"TORCH_VERSION={TORCH_VERSION}; CUDA_VERSION={CUDA_VERSION}")
print(f"CUDA available = {torch.cuda.is_available()}, Device count = {torch.cuda.device_count()}, Current device = {torch.cuda.current_device()}")
print(f"Device name = {torch.cuda.get_device_name()}")
print("hostname:")
!hostname
TORCH_VERSION=torch1.9.0; CUDA_VERSION=cu111
CUDA available = True, Device count = 1, Current device = 0
Device name = GeForce RTX 3080
hostname:
bengio

Dataset

Zach's original notebook used CAMVID, a dataset for segmenting road scenes from cameras on cars. Here we swap in our own data: images whose annotations give ring counts, from which we'll generate segmentation masks below.

path = Path('/home/drscotthawley/datasets/espiownage-cleaner/')  # real data is local and private 

We can generate masks dynamically using espiownage's gen_masks script:

bin_size = 0.7  
maskdir = path / ('masks_'+str(bin_size))
!gen_masks --quiet --step={bin_size} --maskdir={maskdir} --files={str(path/'annotations')+'/*.csv'}
path_im = path/'images'
path_lbl = path/maskdir
 
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
random.shuffle(fnames)
lbl_names = get_image_files(path_lbl)
len(meta_names), len(fnames), len(lbl_names)
(1955, 1955, 1955)
img_fn = fnames[1]
img = PILImage.create(img_fn)
img.show(figsize=(5,5))
get_msk = lambda o: maskdir/f'{o.stem}_P{o.suffix}'  # note maskdir is already a full path

The stem and suffix grab everything before and after the final period, respectively.
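For example, a quick check with a hypothetical filename:

p = Path('06240907_proc_01882.png')   # hypothetical image filename
p.stem, p.suffix                      # -> ('06240907_proc_01882', '.png')
get_msk(p)                            # -> .../masks_0.7/06240907_proc_01882_P.png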

The segmentation masks are not floating point values, rather they're integers obtained by "binning" the ring counts by bin_size, then generating integers as int(ring_count/bin_size).
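A quick sketch of that round trip (the ring count here is illustrative):

ring_count = 3.0                       # a hypothetical fractional ring count
mask_val = int(ring_count / bin_size)  # int(3.0/0.7) = 4, the integer stored in the mask
approx = mask_val * bin_size           # 4*0.7 = 2.8 rings; bin_size sets the quantization error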

Our masks are of type PILMask, and we set the transparency (alpha) to 1 since we are not overlaying the mask on anything yet.

msk = PILMask.create(get_msk(img_fn))
msk.show(figsize=(5,5), alpha=1)

Now if we look at what our mask actually is, we can see it's a giant array of pixels:

tensor(msk)
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.uint8)
set(np.array(msk).flatten())
{0, 1, 5, 11}

the "colors" are the integer values of the quantized ring counts (rescaled by the bin_size)

colors = list(range(int(11/bin_size) + 1))
colors
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

...and because this is built on a classification model, there are text label "codes" (e.g. "dog", "cat"), which for us are just the bin-integer mask values all over again:

codes = [str(n) for n in range(len(colors))]; codes
['0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '10',
 '11',
 '12',
 '13',
 '14',
 '15']
yrange = len(codes); yrange
16

Transfer Learning between Datasets

Jeremy Howard popularized the idea of progressive image resizing:

  • Train on smaller-sized images first
  • Then progressively increase the image size
  • Transfer-learn from each stage to the next (see the sketch below)

For this first round we will train at half the image size.
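Schematically, that loop looks something like this (a minimal sketch; the sizes, batch sizes, and epoch counts are illustrative, and this notebook spells out each stage by hand further down):

def dls_at(size, bs):
    "Rebuild the DataLoaders with images resized to `size`."
    db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_image_files, splitter=RandomSplitter(), get_y=get_msk,
                   batch_tfms=[*aug_transforms(size=size, flip_vert=True),
                               Normalize.from_stats(*imagenet_stats)])
    return db.dataloaders(path/'images', bs=bs)

for size, bs in [((192,256), 4), ((384,512), 2)]:   # half size, then full size
    learn.dls = dls_at(size, bs)    # keep the learned weights, swap in bigger images
    learn.fit_flat_cos(12, slice(lr))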

sz = msk.shape; sz
(384, 512)
half = tuple(int(x/2) for x in sz); half
(192, 256)
db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
    get_items=get_image_files,
    splitter=RandomSplitter(),
    get_y=get_msk,
    batch_tfms=[*aug_transforms(size=half, flip_vert=True), Normalize.from_stats(*imagenet_stats)])
dls = db.dataloaders(path/'images', fnames=fnames, bs=4)

Let's look at a batch, setting vmin and vmax so the mask colors span our full range of codes:

dls.show_batch(max_n=4, vmin=1, vmax=len(codes), figsize=(14,10))

Lastly, let's make our vocabulary a part of our DataLoaders, since our metrics will need to know which label counts as "void" (here, code '0'):

dls.vocab = codes
name2id = {v:k for k,v in enumerate(codes)}
name2id
{'0': 0,
 '1': 1,
 '2': 2,
 '3': 3,
 '4': 4,
 '5': 5,
 '6': 6,
 '7': 7,
 '8': 8,
 '9': 9,
 '10': 10,
 '11': 11,
 '12': 12,
 '13': 13,
 '14': 14,
 '15': 15}
void_code = name2id['0']
def sr_acc_old(inp, targ):          # scores both voids and objects
    targ = targ.squeeze(1)
    return 1 - (inp-targ).abs().round().clamp(max=1).mean() 

def sr_acc(inp, targ, bin_size=1):
    "segmentation regression accuracy: Are we within +/- bin_size?  tries to score only objects, not voids"
    targ = targ.squeeze(1)  
    inp,targ = flatten_check(inp,targ) # https://docs.fast.ai/metrics.html#flatten_check
    mask = targ != void_code  # non_voids
    if len(targ[mask]) == 0:  # Empty image (all void)
        where_correct = (inp-targ).abs() < bin_size              # gonna be ~100%!
    else:
        where_correct = (inp[mask]-targ[mask]).abs() < bin_size  # don't count voids in metric
    return where_correct.float().mean()

def sr_acc05(inp, targ): return sr_acc(inp, targ, bin_size=0.5)
def sr_acc1(inp, targ):  return sr_acc(inp, targ, bin_size=1)
def sr_acc15(inp, targ): return sr_acc(inp, targ, bin_size=1.5)
def sr_acc2(inp, targ):  return sr_acc(inp, targ, bin_size=2)
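As a quick sanity check of the metric, with toy tensors (not real data):

inp  = torch.tensor([[[0.2, 1.4, 5.2]]])  # hypothetical predictions, shape (batch, 1, N)
targ = torch.tensor([[[0.0, 1.0, 5.0]]])  # targets; the 0 is void and gets masked out
sr_acc(inp, targ, bin_size=0.5)           # |1.4-1| and |5.2-5| both < 0.5 -> tensor(1.)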
opt = ranger

UNet_learner

We turn the classifier into a regression model by specifying one output "class" (n_out=1) and then scaling the sigmoid to the range of values we want.
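Under the hood, passing y_range makes fastai apply a scaled sigmoid to the head's output; roughly (a sketch for intuition, not fastai's exact source):

def sigmoid_range(x, lo, hi):
    "Squash x into (lo, hi) -- what the y_range argument does to the final activation."
    return torch.sigmoid(x) * (hi - lo) + lo

sigmoid_range(torch.tensor([-10., 0., 10.]), 0, 19)   # -> roughly [0.0, 9.5, 19.0]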

hrfac = 1.2  # 'headroom factor'
y_range=(0,int(len(codes)*hrfac))  # balance between "clamping" to the range of real data vs. too much "compression" from the sigmoid nonlinearity
metrics = [mae, sr_acc_old, sr_acc05, sr_acc1, sr_acc15, sr_acc2]
learn = unet_learner(dls, resnet34, n_out=1, y_range=y_range, loss_func=MSELossFlat(), metrics=metrics, self_attention=True, act_cls=Mish, opt_func=opt)
lr = learn.lr_find().valley
print("Suggested Learning Rate =",lr)
Suggested Learning Rate = 0.0012022644514217973
lr = 1e-3

With our new optimizer, we will also want to use a different fit function, called fit_flat_cos, which holds the learning rate flat for most of training and then anneals it with a cosine curve.
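Roughly, the schedule looks like this (a sketch of the idea using fastai's scheduler pieces, not the exact source of fit_flat_cos):

from fastai.callback.schedule import combine_scheds, SchedNo, SchedCos, ParamScheduler
# hold lr flat for the first 75% of training, then cosine-anneal toward ~zero
sched = combine_scheds([0.75, 0.25], [SchedNo(lr, lr), SchedCos(lr, lr/1e5)])
# learn.fit(n_epoch, cbs=ParamScheduler({'lr': sched})) is approximately
# what learn.fit_flat_cos(n_epoch, lr) does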

learn.fit_flat_cos(12, slice(lr))  # these frozen epochs don't yield much improvement btw
epoch train_loss valid_loss mae sr_acc_old sr_acc05 sr_acc1 sr_acc15 sr_acc2 time
0 7.089978 5.491747 0.979635 0.668919 0.110545 0.296386 0.402533 0.508916 00:29
1 6.036098 5.456071 0.932821 0.675291 0.119492 0.314745 0.425798 0.541563 00:28
2 5.697878 5.102298 1.000938 0.627256 0.154435 0.332985 0.440755 0.530337 00:29
3 5.267841 5.137025 0.943078 0.654377 0.153771 0.344551 0.460436 0.555873 00:29
4 4.952352 5.093339 0.892314 0.686509 0.133756 0.304471 0.397107 0.503860 00:28
5 4.834439 4.736519 0.896717 0.659961 0.170082 0.359108 0.477284 0.569669 00:29
6 4.881256 4.808916 0.891781 0.670739 0.148911 0.336322 0.439235 0.532952 00:29
7 4.894568 4.833971 0.875522 0.673204 0.158937 0.357425 0.481381 0.591281 00:28
8 4.884807 4.829686 0.819849 0.703676 0.137873 0.331885 0.425504 0.520800 00:28
9 4.780081 4.685006 0.839431 0.682715 0.147457 0.365742 0.465724 0.555164 00:28
10 4.396173 4.550979 0.857806 0.664973 0.210451 0.393658 0.498187 0.584454 00:28
11 4.103733 4.469028 0.794565 0.682472 0.216281 0.399096 0.501662 0.596793 00:29
learn.save('seg_reg_real_1')   
learn.load('seg_reg_real_1')
<fastai.learner.Learner at 0x7fea4d1ae9d0>
learn.show_results(max_n=6, figsize=(10,10))

^^ If the right column above is all blue, that's bad. But let's unfreeze and see if the model can learn...

Let's unfreeze the model and decrease our learning rate by a factor of 4 (rule of thumb):

lrs = slice(lr/400, lr/4)
lr, lrs
(0.001, slice(2.5e-06, 0.00025, None))
learn.unfreeze()

And train for a bit more

learn.fit_flat_cos(12, lrs)
epoch train_loss valid_loss mae sr_acc_old sr_acc05 sr_acc1 sr_acc15 sr_acc2 time
0 4.444946 4.495612 0.785001 0.687291 0.206691 0.390643 0.500603 0.600335 00:31
1 4.126898 4.477420 0.826504 0.671112 0.198250 0.392698 0.500209 0.583497 00:31
2 4.086650 4.336927 0.763674 0.690917 0.198003 0.381867 0.483945 0.571887 00:31
3 4.329653 4.261275 0.794148 0.686370 0.191284 0.369669 0.467303 0.550883 00:31
4 4.063001 4.255717 0.767850 0.686955 0.206374 0.398319 0.500243 0.591661 00:31
5 3.898715 4.132246 0.804221 0.668942 0.186322 0.384170 0.496879 0.589402 00:31
6 4.197900 4.286320 0.825389 0.659126 0.224167 0.397995 0.514112 0.637275 00:31
7 3.795408 3.967359 0.742274 0.684181 0.226704 0.416520 0.518135 0.599069 00:31
8 3.967145 4.040976 0.785223 0.672122 0.208898 0.388723 0.501708 0.613508 00:31
9 3.817849 3.801829 0.700488 0.696286 0.202370 0.382298 0.493097 0.640270 00:31
10 3.573793 3.671768 0.713289 0.682646 0.221985 0.395955 0.503925 0.628679 00:31
11 3.429790 3.710968 0.707290 0.685755 0.224257 0.403588 0.513950 0.663363 00:31
learn.save('seg_reg_real_2')
Path('models/seg_reg_real_2.pth')
learn.show_results(max_n=6, figsize=(10,10))

Inference

Let's take a look at how to do inference with test_dl

dl = learn.dls.test_dl(fnames[:5]) 
dl.show_batch()

Let's do the first five pictures

preds = learn.get_preds(dl=dl)
preds[0].shape
torch.Size([5, 1, 192, 256])
pred_1 = preds[0][0].squeeze()
pred_1.shape
torch.Size([192, 256])
msk = PILMask.create(pred_1)
msk.show(figsize=(5,5), alpha=1)
pred_arx = pred_1.numpy()
def save_tmask(tmask, fname='', norm=False): # save tensor mask as an 8-bit grayscale image
    tmask_new = tmask[0].squeeze().cpu().numpy()
    use_min, use_max = 0, np.max(np.array(colors))    # default: use the scale of the max ring count
    if norm: use_min, use_max = tmask_new.min(), tmask_new.max()   # auto-scale for just this image
    rescaled = (255.0 / (use_max - use_min) * (tmask_new - use_min)).astype(np.uint8)  # map to 0..255
    im = Image.fromarray(rescaled)
    if fname != '': im.save(fname)
    return im
im = save_tmask(preds[0])
im
plt.imshow(im)
plt.colorbar()

Here we have actual ring values. Note that without vmin=...,vmax=... kwargs, plt.imshow will auto-normalize the colors to take up the maximum range for this image alone.

plt.imshow(pred_arx*bin_size)
plt.colorbar()

Actually that's not too bad!

Full Size

db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
    get_items=get_image_files,
    splitter=RandomSplitter(),
    get_y=get_msk,
    batch_tfms=[*aug_transforms(size=sz, flip_vert=True), Normalize.from_stats(*imagenet_stats)])
dls = db.dataloaders(path/'images', fnames=fnames, bs=2)  # smaller batch size because we're now full size
dls.vocab = codes
learn = unet_learner(dls, resnet34, n_out=1, y_range=y_range, loss_func=MSELossFlat(), metrics=metrics, self_attention=True, act_cls=Mish, opt_func=opt)
learn.load('seg_reg_real_2')
<fastai.learner.Learner at 0x7fe9342206d0>
learn.lr_find(end_lr=5e-3)
SuggestedLRs(valley=0.00014070511679165065)
lr = 3e-4
learn.fit_flat_cos(10, slice(lr))
epoch train_loss valid_loss mae sr_acc_old sr_acc05 sr_acc1 sr_acc15 sr_acc2 time
0 5.386307 5.372129 1.013704 0.649748 0.127289 0.332420 0.436476 0.517089 01:31
1 4.682086 5.096883 0.927518 0.679011 0.132360 0.332205 0.435346 0.518651 01:30
2 4.757010 5.169637 0.917902 0.692183 0.166780 0.337664 0.426051 0.547509 01:30
3 4.144821 4.934378 0.859546 0.702090 0.173864 0.339585 0.430403 0.582440 01:30
4 4.353267 5.346978 0.857710 0.716170 0.152878 0.308072 0.392022 0.536225 01:30
5 4.252083 4.826119 0.869468 0.691910 0.178491 0.347969 0.438572 0.523875 01:30
6 4.414810 5.091223 0.861025 0.704273 0.169114 0.354011 0.462759 0.594171 01:30
7 4.358559 4.679442 0.808027 0.708888 0.184589 0.347489 0.442850 0.563372 01:30
8 3.914483 4.799042 0.839595 0.694164 0.196506 0.373271 0.485033 0.605448 01:30
9 3.563733 4.515018 0.802791 0.702344 0.190167 0.366735 0.469429 0.583833 01:30
learn.save('seg_reg_full_real_1')
Path('models/seg_reg_full_real_1.pth')
learn.load('seg_reg_full_real_1')
<fastai.learner.Learner at 0x7fe9342206d0>
learn.unfreeze()
lrs = slice(1e-6,lr/10); lrs
slice(1e-06, 2.9999999999999997e-05, None)
learn.fit_flat_cos(10, lrs)
epoch train_loss valid_loss mae sr_acc_old sr_acc05 sr_acc1 sr_acc15 sr_acc2 time
0 3.799974 4.523741 0.813144 0.700424 0.186073 0.355542 0.455982 0.573185 01:37
1 3.457279 4.358019 0.769677 0.710907 0.187344 0.362932 0.473585 0.595600 01:37
2 3.715190 4.513614 0.796626 0.704561 0.189954 0.359431 0.465413 0.584217 01:37
3 3.458338 4.380141 0.778749 0.708567 0.175998 0.358986 0.465622 0.575038 01:37
4 3.868106 4.303098 0.764538 0.709471 0.184271 0.361908 0.469391 0.582626 01:37
5 3.657902 4.268610 0.759626 0.713752 0.181757 0.357685 0.465071 0.584969 01:37
6 3.688116 4.171949 0.750382 0.712621 0.188429 0.363649 0.471641 0.599138 01:37
7 3.622725 4.149560 0.753431 0.710179 0.190857 0.363890 0.472001 0.604596 01:37
8 3.469848 4.088984 0.744515 0.713348 0.186447 0.364163 0.476983 0.599486 01:37
9 3.930444 4.059169 0.739791 0.713391 0.189513 0.367989 0.478807 0.593679 01:37
learn.save('seg_reg_full_real_2')
Path('models/seg_reg_full_real_2.pth')
learn.show_results(max_n=10, figsize=(18,8))  # todo: need to show 'soft' version of images

Inference (Full Size)

learn.load('seg_reg_full_real_2')
<fastai.learner.Learner at 0x7fe9342206d0>
preds, targs, losses = learn.get_preds(with_loss=True) # validation set only
print(preds.shape, targs.shape)
len(preds)
torch.Size([391, 1, 384, 512]) torch.Size([391, 384, 512])
391
seg_img_dir = 'seg_reg_images'
!rm -rf {seg_img_dir}; mkdir {seg_img_dir}
results = []
for i in range(len(preds)):
    filestem = dls.valid.items[i].stem
    line_list = [filestem]+[losses[i].cpu().numpy(), i]
    save_tmask(preds[i], seg_img_dir+'/'+filestem+'_pred.png')
    results.append(line_list)

# store as pandas dataframe
res_df = pd.DataFrame(results, columns=['filename', 'loss','i'])
res_df = res_df.sort_values('loss', ascending=False)
res_df.to_csv('segreg_top_losses_real.csv', index=False)
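Later we can read that CSV back to pull up the worst predictions (a quick sketch; the '_pred.png' mask images were written by save_tmask above):

worst = pd.read_csv('segreg_top_losses_real.csv').head(5)   # highest losses first
for _, row in worst.iterrows():
    print(row['filename'], row['loss'])   # matching mask image: seg_reg_images/<filename>_pred.png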

Plot target then predicted:

print(dls.valid.items[0].stem+':')
plt.imshow(targs[0].cpu().numpy()*bin_size, vmin=0, vmax=11)
plt.colorbar()
print("Target:")
06240907_proc_01882:
Target:
print(dls.valid.items[0].stem+':')
#plt.imshow(targs[0].cpu().numpy()*bin_size, vmin=0, vmax=11)
plt.imshow(preds[0][0].cpu().numpy()*bin_size, vmin=0, vmax=11)
plt.colorbar()
print("Predicted:")
06240907_proc_01882:
Predicted:

Here's a plot where we're careful to scale the colors according to the ring counts:

for i in range(10):
    j = i + 200 # move away from zero because I keep seeing the same thing ;-)
    print(dls.valid.items[j].stem+': (targ, pred)')
    fig, axarr = plt.subplots(1,2, figsize=(12,4))
    axarr[0].imshow(targs[j].cpu().numpy()*bin_size, vmin=0, vmax=11)
    axarr[1].imshow(preds[j][0].cpu().numpy()*bin_size, vmin=0, vmax=11)
    plt.show()
06241902_proc_01195: (targ, pred)
06241902_proc_01470: (targ, pred)
06240907_proc_01714: (targ, pred)
06240907_proc_00774: (targ, pred)
06240907_proc_01403: (targ, pred)
06240907_proc_01868: (targ, pred)
06241902_proc_01942: (targ, pred)
06241902_proc_01121: (targ, pred)
06240907_proc_00867: (targ, pred)
06240907_proc_01021: (targ, pred)