Acknowledgement: I took Zach Mueller's Image Segmentation tutorial notebook (based on the main FastAI lesson notebook) and modified it to do regression (as per Zach's suggestions) and to work with my own data.
!pip install -Uqq espiownage
from fastai.vision.all import *
from espiownage.core import *
Below you will find the exact imports for everything we use today:
from fastcore.xtras import Path
from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos
from fastai.data.block import DataBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, FuncSplitter, RandomSplitter, Normalize
from fastai.layers import Mish # MishJIT gives me trouble :-(
from fastai.losses import BaseLoss, MSELossFlat, CrossEntropyLossFlat, BCEWithLogitsLossFlat
from fastai.optimizer import ranger
from fastai.torch_core import tensor, flatten_check
from fastai.metrics import mae
from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
from fastai.vision.learner import unet_learner
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
from torch import nn
from torchvision.models.resnet import resnet34
import torch
import torch.nn.functional as F
import glob
from pathlib import Path
import re
tv, cv = torch.__version__, torch.version.cuda
tv = re.sub(r'\+cu.*', '', tv)          # strip any '+cuXXX' build suffix
TORCH_VERSION = 'torch'+tv[0:-1]+'0'    # e.g. '1.9.1' -> 'torch1.9.0'
CUDA_VERSION = 'cu'+cv.replace('.','')
print(f"TORCH_VERSION={TORCH_VERSION}; CUDA_VERSION={CUDA_VERSION}")
print(f"CUDA available = {torch.cuda.is_available()}, Device count = {torch.cuda.device_count()}, Current device = {torch.cuda.current_device()}")
print(f"Device name = {torch.cuda.get_device_name()}")
print("hostname:")
!hostname
path = Path('/home/drscotthawley/datasets/espiownage-cleaner/') # real data is local and private
We can generate masks dynamically using espiownage's gen_masks script:
bin_size = 0.7
maskdir = path / ('masks_'+str(bin_size))
!gen_masks --quiet --step={bin_size} --maskdir={maskdir} --files={str(path/'annotations')+'/*.csv'}
path_im = path/'images'
path_lbl = path/maskdir
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
random.shuffle(fnames)
lbl_names = get_image_files(path_lbl)
len(meta_names), len(fnames), len(lbl_names)
img_fn = fnames[1]
img = PILImage.create(img_fn)
img.show(figsize=(5,5))
get_msk = lambda o: path/maskdir/f'{o.stem}_P{o.suffix}'
The stem and suffix grab everything before and after the period respectively.
The segmentation masks are not floating-point values; rather, they're integers obtained by "binning" the ring counts by bin_size, i.e. computing int(ring_count/bin_size).
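To make that concrete, here's a tiny hedged sketch (the filename and ring count below are made up for illustration; only get_msk and bin_size come from this notebook):
demo = Path('images/example.png')          # hypothetical image file
print(demo.stem, demo.suffix)              # -> example .png
print(get_msk(demo))                       # -> .../masks_0.7/example_P.png
ring_count = 5.3                           # made-up ring count
mask_val = int(ring_count / bin_size)      # binning: int(5.3/0.7) = 7
print(mask_val, mask_val * bin_size)       # un-binning gives back ~4.9 rings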
Our masks are of type PILMask, and we will make our gradient percentage (alpha) equal to 1, as we are not overlaying the mask on anything yet:
msk = PILMask.create(get_msk(img_fn))
msk.show(figsize=(5,5), alpha=1)
Now if we look at what our mask actually is, we can see it's a giant array of pixels:
tensor(msk)
set(np.array(msk).flatten())
the "colors" are the integer values of the quantized ring counts (rescaled by the bin_size)
colors = list(range(int(11/bin_size) + 1))
colors
...and because this is based on a classification model, there are text label "codes" (e.g. "dog", "cat"), which for us are just the bin-integer mask values all over again:
codes = [str(n) for n in range(len(colors))]; codes
yrange = len(codes); yrange
For this first round, we will train at half the image size:
sz = msk.shape; sz
half = tuple(int(x/2) for x in sz); half
db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
               get_items=get_image_files,
               splitter=RandomSplitter(),
               get_y=get_msk,
               batch_tfms=[*aug_transforms(size=half, flip_vert=True), Normalize.from_stats(*imagenet_stats)])
dls = db.dataloaders(path/'images', fnames=fnames, bs=4)
Let's look at a batch, scaling the displayed mask values between 1 and len(codes) (i.e. ignoring the void code 0):
dls.show_batch(max_n=4, vmin=1, vmax=len(codes), figsize=(14,10))
Lastly, let's make our vocabulary a part of our DataLoaders, as our metrics will need to deal with the void ('0') label:
dls.vocab = codes
name2id = {v:k for k,v in enumerate(codes)}
name2id
void_code = name2id['0']
def sr_acc_old(inp, targ):  # scores both voids and objects
    targ = targ.squeeze(1)
    return 1 - (inp-targ).abs().round().clamp(max=1).mean()  # fraction of pixels within half a bin
def sr_acc(inp, targ, bin_size=1):
    "segmentation regression accuracy: are we within +/- bin_size? tries to score only objects, not voids"
    targ = targ.squeeze(1)
    inp, targ = flatten_check(inp, targ)  # https://docs.fast.ai/metrics.html#flatten_check
    mask = targ != void_code              # non-voids
    if len(targ[mask]) == 0:              # empty image (all void)
        where_correct = (inp-targ).abs() < bin_size              # gonna be ~100%!
    else:
        where_correct = (inp[mask]-targ[mask]).abs() < bin_size  # don't count voids in the metric
    return where_correct.float().mean()
def sr_acc05(inp, targ): return sr_acc(inp, targ, bin_size=0.5)
def sr_acc1(inp, targ): return sr_acc(inp, targ, bin_size=1)
def sr_acc15(inp, targ): return sr_acc(inp, targ, bin_size=1.5)
def sr_acc2(inp, targ): return sr_acc(inp, targ, bin_size=2)
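As a quick sanity check, here's a tiny fabricated example (the tensors below are made up, shaped like a batch of one 1x4 "image"; the last pixel is void). sr_acc only scores the non-void pixels, so the pixel that's off by 0.6 bins passes sr_acc1 but fails sr_acc05:
inp  = tensor([[[2.6, 4.0, 9.2, 0.1]]])   # made-up predictions (float bin values)
targ = tensor([[[2.0, 4.0, 8.0, 0.0]]])   # made-up targets; the 0 is a void pixel
print(sr_acc05(inp, targ))                # 1/3: only the exact match is within 0.5
print(sr_acc1(inp, targ))                 # 2/3: the 0.6 error now counts as correct
print(sr_acc_old(inp, targ))              # 1/2: scores the void pixel too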
opt = ranger
hrfac = 1.2  # 'headroom factor'
y_range = (0, int(len(codes)*hrfac))  # balance between "clamping" to the range of real data vs. too much "compression" from the sigmoid nonlinearity
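For intuition: fastai implements y_range as a scaled sigmoid on the raw head output, y = sigmoid(x)*(hi-lo)+lo, which softly clamps predictions to (lo, hi); the headroom keeps real ring counts away from the flat upper tail of the sigmoid. A minimal sketch (the sample values are made up):
from fastai.layers import sigmoid_range
x = tensor([-10., 0., 2., 10.])       # made-up raw network outputs
print(sigmoid_range(x, *y_range))     # squashed into (0, len(codes)*hrfac)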
metrics = [mae, sr_acc_old, sr_acc05, sr_acc1, sr_acc15, sr_acc2]
learn = unet_learner(dls, resnet34, n_out=1, y_range=y_range, loss_func=MSELossFlat(), metrics=metrics, self_attention=True, act_cls=Mish, opt_func=opt)
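Since we passed n_out=1, the U-Net emits a single regression channel per pixel rather than one channel per class. An optional sanity check (my addition, not part of the original lesson):
xb, yb = dls.one_batch()                  # one batch from the training dataloader
with torch.no_grad():
    out = learn.model.to(dls.device)(xb)  # forward pass
print(out.shape)                          # (bs, 1, H, W): one float per pixel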
lr = learn.lr_find().valley
print("Suggested Learning Rate =",lr)
lr = 1e-3  # override the suggestion with a manually-chosen value
With our new optimizer, we will also want to use a different fit function, called fit_flat_cos:
learn.fit_flat_cos(12, slice(lr)) # these frozen epochs don't yield much improvement btw
learn.save('seg_reg_real_1')
learn.load('seg_reg_real_1')
learn.show_results(max_n=6, figsize=(10,10))
^^ If the right column above is all blue, that's bad. But let's unfreeze and see if the model can learn...
Let's unfreeze the model and decrease our learning rate by a factor of 4 (rule of thumb):
lrs = slice(lr/400, lr/4)
lr, lrs
learn.unfreeze()
And train for a bit more
learn.fit_flat_cos(12, lrs)
learn.save('seg_reg_real_2')
learn.show_results(max_n=6, figsize=(10,10))
dl = learn.dls.test_dl(fnames[:5])
dl.show_batch()
Let's run inference on the first five pictures:
preds = learn.get_preds(dl=dl)
preds[0].shape
pred_1 = preds[0][0].squeeze()
pred_1.shape
msk = PILMask.create(pred_1)
msk.show(figsize=(5,5), alpha=1)
pred_arx = pred_1.numpy()
def save_tmask(tmask, fname='', norm=False):  # save tensor mask
    tmask_new = tmask[0].squeeze().cpu().numpy()
    use_min, use_max = 0, np.max(np.array(colors))  # use the scale of the max ring count
    if norm: use_min, use_max = tmask_new.min(), tmask_new.max()  # auto-scale for just this image
    rescaled = (255.0 / use_max * (tmask_new - use_min)).astype(np.uint8)
    im = Image.fromarray(rescaled)
    if fname != '': im.save(fname)
    return im
im = save_tmask(preds[0])
im
plt.imshow(im)
plt.colorbar()
Here we have actual ring values. Note that without vmin=..., vmax=... kwargs, plt.imshow will auto-normalize the colors to take up the maximum range for this image alone:
plt.imshow(pred_arx*bin_size)
plt.colorbar()
Actually that's not too bad!
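For comparison, pinning the color scale to the full ring-count range (0 to 11, the maximum ring count assumed above) makes this plot directly comparable with the target plots further below:
plt.imshow(pred_arx*bin_size, vmin=0, vmax=11)
plt.colorbar()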
db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
               get_items=get_image_files,
               splitter=RandomSplitter(),
               get_y=get_msk,
               batch_tfms=[*aug_transforms(size=sz, flip_vert=True), Normalize.from_stats(*imagenet_stats)])
dls = db.dataloaders(path/'images', fnames=fnames, bs=2) # smaller batch size because we're now full size
dls.vocab = codes
learn = unet_learner(dls, resnet34, n_out=1, y_range=y_range, loss_func=MSELossFlat(), metrics=metrics, self_attention=True, act_cls=Mish, opt_func=opt)
learn.load('seg_reg_real_2')
learn.lr_find(end_lr=5e-3)
lr = 3e-4
learn.fit_flat_cos(10, slice(lr))
learn.save('seg_reg_full_real_1')
learn.load('seg_reg_full_real_1')
learn.unfreeze()
lrs = slice(1e-6,lr/10); lrs
learn.fit_flat_cos(10, lrs)
learn.save('seg_reg_full_real_2')
learn.show_results(max_n=10, figsize=(18,8)) # todo: need to show 'soft' version of images
learn.load('seg_reg_full_real_2')
preds, targs, losses = learn.get_preds(with_loss=True) # validation set only
print(preds.shape, targs.shape)
len(preds)
seg_img_dir = 'seg_reg_images'
!rm -rf {seg_img_dir}; mkdir {seg_img_dir}
results = []
for i in range(len(preds)):
    filestem = dls.valid.items[i].stem
    line_list = [filestem] + [losses[i].cpu().numpy(), i]
    save_tmask(preds[i], seg_img_dir+'/'+filestem+'_pred.png')
    results.append(line_list)
# store as pandas dataframe
res_df = pd.DataFrame(results, columns=['filename', 'loss','i'])
res_df = res_df.sort_values('loss', ascending=False)
res_df.to_csv('segreg_top_losses_real.csv', index=False)
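To eyeball the worst offenders, a trivial usage example is to peek at the top of the sorted dataframe:
res_df.head(10)  # highest-loss validation images first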
Plot the target, then the predicted:
print("Target: "+dls.valid.items[0].stem)
plt.imshow(targs[0].cpu().numpy()*bin_size, vmin=0, vmax=11)
plt.colorbar()
print("Predicted: "+dls.valid.items[0].stem)
plt.imshow(preds[0][0].cpu().numpy()*bin_size, vmin=0, vmax=11)
plt.colorbar()
Here's a plot where we're careful to scale the colors according to the ring counts:
for i in range(10):
    j = i + 200  # move away from zero because I keep seeing the same thing ;-)
    print(dls.valid.items[j].stem+': (targ, pred)')
    fig, axarr = plt.subplots(1, 2, figsize=(12,4))
    axarr[0].imshow(targs[j].cpu().numpy()*bin_size, vmin=0, vmax=11)
    axarr[1].imshow(preds[j][0].cpu().numpy()*bin_size, vmin=0, vmax=11)
    plt.show()