Here we take Zach Mueller's CAMVID Segmentation Tutorial and try to segment our real data as object vs. background ("all one" class rather than multiple classes).
import torch, re
tv, cv = torch.__version__, torch.version.cuda
tv = re.sub(r'\+cu.*', '', tv)              # strip any "+cuXXX" build suffix
TORCH_VERSION = 'torch' + tv[0:-1] + '0'    # e.g. "1.9.1" -> "torch1.9.0"
CUDA_VERSION = 'cu' + cv.replace('.', '')   # e.g. "11.1" -> "cu111"
print(f"TORCH_VERSION={TORCH_VERSION}; CUDA_VERSION={CUDA_VERSION}")
print(f"CUDA available = {torch.cuda.is_available()}, Device count = {torch.cuda.device_count()}, Current device = {torch.cuda.current_device()}")
print(f"Device name = {torch.cuda.get_device_name()}")
print("hostname:")
!hostname
from fastai.vision.all import *
from espiownage.core import *
Below you will find the exact imports for everything we use today
from fastcore.xtras import Path
from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos
from fastai.data.block import DataBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, FuncSplitter, Normalize
from fastai.layers import Mish
from fastai.losses import BaseLoss
from fastai.optimizer import ranger
from fastai.torch_core import tensor
from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
from fastai.vision.learner import unet_learner
from PIL import Image
import numpy as np
from torch import nn
from torchvision.models.resnet import resnet34
import torch
import torch.nn.functional as F
path = get_data('cleaner'); path
Let's look at an image and see how everything lines up
path_im = path/'images'
path_lbl = path/'masks'
First we need our filenames
import glob
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
lbl_names = get_image_files(path_lbl)
len(meta_names), len(fnames), len(lbl_names)
And now let's work with one of them
img_fn = fnames[10]
print(img_fn)
img = PILImage.create(img_fn)
img.show(figsize=(5,5))
Now let's grab our y's. They live in the masks folder and are denoted by a _P in the filename
get_msk = lambda o: path/'masks'/f'{o.stem}_P{o.suffix}'
The stem and suffix grab everything before and after the period respectively.
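For example, with a made-up filename (not one from our dataset):
p = Path('images/frame_0010.png')  # hypothetical example file
p.stem, p.suffix                   # -> ('frame_0010', '.png')
get_msk(p)                         # -> Path('.../masks/frame_0010_P.png')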
Our masks are of type PILMask and we will make our gradient percentage (alpha) equal to 1, as we are not overlaying this on anything yet
msk_name = get_msk(img_fn)
print(msk_name)
msk = PILMask.create(msk_name)
msk.show(figsize=(5,5), alpha=1)
Now if we look at what our mask actually is, we can see it's a giant array of pixels:
print(tensor(msk))
And let's just make sure it's a simple mask with no antialiasing. Let's see what values it contains:
set(np.array(msk).flatten())
Each value represents a class (the original CAMVID data lists these in a codes.txt file; here we build our own). Let's make a vocabulary with it
#colors = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110]
#colors = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
colors = list(set(np.array(msk).flatten()))
codes = [str(n) for n in range(len(colors))]; codes
We need a way of splitting our data into training and validation sets. The original tutorial builds its own splitter, which checks each filename against a list of validation filenames; here we'll just use fastai's RandomSplitter, but a file-based version would look something like the sketch below.
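A minimal sketch of such a splitter, using fastai's FuncSplitter and a hypothetical valid.txt listing validation filenames (we don't actually use this; RandomSplitter handles our split):
def FileSplitter_like(valid_fnames):
    "Send items whose name appears in `valid_fnames` to the validation set"
    def _is_valid(o): return o.name in valid_fnames
    return FuncSplitter(_is_valid)
#valid_fnames = Path('valid.txt').read_text().split('\n')  # hypothetical validation list
#splitter = FileSplitter_like(valid_fnames)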
This first round we will train at half the image size
sz = msk.shape; sz
half = tuple(int(x/2) for x in sz); half
db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
               get_items=get_image_files,
               splitter=RandomSplitter(),
               get_y=get_msk,
               batch_tfms=[*aug_transforms(size=half, flip_vert=True), Normalize.from_stats(*imagenet_stats)])
dls = db.dataloaders(path/'images', fnames=fnames, bs=4)
Let's look at a batch, and look at all the classes
dls.show_batch(max_n=4, vmin=0, vmax=2, figsize=(14,10))
Lastly let's make our vocabulary a part of our DataLoaders, as our loss function needs to deal with the Void label
dls.vocab = codes
Now we need a way of grabbing a particular code from our output of numbers. Let's make everything into a dictionary
name2id = {v: int(v) for v in codes}
name2id
Awesome! Let's make an accuracy function
void_code = name2id['0'] # name2id['Void']
For segmentation, we squeeze the target down to a matrix of class digits for our segmentation mask. From there, we match the prediction's per-pixel argmax against the target mask and take the average
def acc_camvid(inp, targ):
    targ = targ.squeeze(1)
    mask = targ != void_code
    if len(targ[mask]) == 0: mask = (targ == void_code)  # Empty image (all void)
    return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()
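As a quick sanity check (a toy example, not part of the original notebook), a perfect prediction on a tiny 2x2 fake batch should score 1.0:
inp  = torch.tensor([[[[0., 1.],[1., 0.]],     # logits for class 0 (void)
                      [[1., 0.],[0., 1.]]]])   # logits for class 1 (object)
targ = torch.tensor([[[[1, 0],[0, 1]]]])       # matching target mask, shape (bs, 1, h, w)
acc_camvid(inp, targ)                          # -> tensor(1.)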
U-Net allows us to look at pixel-wise representations of our images by sizing them down and then blowing them back up into a high-resolution image. The first part we call an "encoder" and the second a "decoder"
On the image, the authors of the U-Net paper use the arrows to denote the different operations
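To make the "size it down, blow it back up" idea concrete, here's a minimal sketch (purely illustrative, nothing like the actual fastai U-Net): one strided convolution halves the spatial dimensions and one transposed convolution restores them:
enc = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)   # one "encoder" step: halves H and W
dec = nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2)     # one "decoder" step: doubles H and W
x = torch.randn(1, 3, 192, 256)                              # a fake image batch
z = enc(x); z.shape, dec(z).shape                            # -> (1, 16, 96, 128) and (1, 3, 192, 256)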
We have a special unet_learner. Something new is that we can pass in some model configurations to customize it with:
- Blur/blur final: avoid checkerboard artifacts
- Self attention: a self-attention layer
- y_range: last activations go through a sigmoid for rescaling
- Last cross: a cross-connection with the direct model input
- Bottle: bottleneck or not on that cross
- Activation function
- Norm type
Let's make a unet_learner that uses some of the new state-of-the-art techniques. Specifically:
- Self-attention layers: self_attention = True
- Mish activation function: act_cls = Mish
Along with this we will use Ranger as our optimizer.
opt = ranger
learn = unet_learner(dls, resnet34, metrics=acc_camvid, self_attention=True, act_cls=Mish, opt_func=opt)
learn.summary()
If we do a learn.summary we can see this blow-up trend, and see that our model came in frozen. Let's find a learning rate
learn.lr_find()
lr = 1e-4
With our new optimizer, we will also want to use a different fit function, called fit_flat_cos
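fit_flat_cos trains at a flat learning rate for most of the schedule and then anneals to (nearly) zero along a cosine curve. A rough numpy sketch of the shape (assuming fastai's default of annealing over roughly the last 25% of training):
import numpy as np, matplotlib.pyplot as plt
pct = np.linspace(0, 1, 100)                                                   # fraction of training
sched = np.where(pct < 0.75, lr, lr*(1 + np.cos(np.pi*(pct - 0.75)/0.25))/2)   # flat, then cosine to ~0
plt.plot(pct, sched); plt.xlabel('fraction of training'); plt.ylabel('learning rate')
After fitting, learn.recorder.plot_sched() should show the actual schedule that was used.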
learn.fit_flat_cos(12, slice(lr))
learn.save('stage-1-real') # Zach saves in case Colab dies / gives OOM
learn.load('stage-1-real'); # he reloads as a way of skipping what came before if he restarts the runtime.
learn.show_results(max_n=4, figsize=(12,6))
Let's unfreeze the model and decrease our learning rate by a factor of 4 (a common rule of thumb)
lrs = slice(lr/400, lr/4)
lr, lrs
learn.unfreeze()
And train for a bit more
learn.fit_flat_cos(15, lrs)
Now let's save that model away
learn.save('seg_allone_half_real')
learn.load('seg_allone_half_real')
And look at a few results
learn.show_results(max_n=6, figsize=(10,10))
dl = learn.dls.test_dl(fnames[0:8])
dl.show_batch()
Let's get predictions for those first eight pictures
preds = learn.get_preds(dl=dl)
len(preds)
preds[0].shape
Alright, so we have an images x classes x height x width tensor: our eight test images, with one channel per class (the original CAMVID tutorial showed 5x32x360x480 here)
len(codes)
What does this mean? The first axis runs over the images in our batch, so each slice along it is the prediction for one image. Let's look at one of them
ind = 6
pred_1 = preds[0][ind]
pred_1.shape
Now let's take the argmax of our values
pred_arx = pred_1.argmax(dim=0)
And look at it
plt.imshow(pred_arx)
What do we do from here? We need to save it away. We can do this in one of two ways: as a NumPy array converted to an image, or as a raw tensor (to use later in raw form)
pred_arx = pred_arx.numpy()
rescaled = (255.0 / pred_arx.max() * (pred_arx - pred_arx.min())).astype(np.uint8)  # scale class indices to the full 0..255 range (binary 0/1 becomes 0/255)
im = Image.fromarray(rescaled)
im
im.save('test_real.png')
Let's loop over our predictions and do the same for each file
for i, pred in enumerate(preds[0]):
    pred_arg = pred.argmax(dim=0).numpy()
    rescaled = (255.0 / pred_arg.max() * (pred_arg - pred_arg.min())).astype(np.uint8)
    im = Image.fromarray(rescaled)
    im.save(f'Image_{i}_real.png')
Now let's save away the raw:
torch.save(preds[0][ind], 'Image_1_real.pt')
pred_1 = torch.load('Image_1_real.pt')
plt.imshow(pred_1.argmax(dim=0))
(At this point the notebook was restarted for full-size training, so we redo our imports and setup.)
from fastai.vision.all import *
from espiownage.core import *
import glob
path = Path('/home/drscotthawley/datasets/espiownage-cleaner/')
path_im = path/'images'
path_lbl = path/'masks'
#path = untar_data('https://anonymized.machine.com/~drscotthawley/espiownage-cleaner.tgz')
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
lbl_names = get_image_files(path_lbl)
get_msk = lambda o: path/'masks'/f'{o.stem}_P{o.suffix}'
colors = [0, 1]
codes = [str(n) for n in colors]; codes
sz = (384, 512)
name2id = {v: int(v) for v in codes}
void_code = name2id['0']
And re-make our dataloaders, but this time we want our size to be the full size
seg_db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_image_files,
                   splitter=RandomSplitter(),
                   get_y=get_msk,
                   batch_tfms=[*aug_transforms(size=sz, flip_vert=True), Normalize.from_stats(*imagenet_stats)])
We'll also want to lower our batch size to not run out of memory
dls = seg_db.dataloaders(path/"images", fnames=fnames, bs=2)
Let's assign our vocab, make our learner, and load our weights
opt = ranger
def acc_camvid2(inp, targ, warn=False):
    targ = targ.squeeze(1)
    mask = targ != void_code  # where it's nonzero
    if len(targ[mask]) == 0:  # Empty image (all void)
        mask = (targ == void_code)
        if warn:
            acc_empty = (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()  # score based on what's correct overall (~100%?)
            print("Empty image, acc_empty = ", acc_empty.cpu().numpy())
    return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()
def acc_camvid3(inp, targ):
    mask = inp.argmax(dim=1) == targ.squeeze(1)  # could give inflated scores for images dominated by void
    return mask.float().mean()
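To see the inflation that comment warns about (a toy check, not from the original notebook): on an image that is three-quarters void, a model predicting "all void" scores 0.75 under acc_camvid3 but 0.0 under the masked metric:
inp  = torch.zeros(1, 2, 2, 2); inp[:, 0] += 1.   # logits that always predict class 0 (void)
targ = torch.tensor([[[[0, 0],[0, 1]]]])          # 3 of 4 pixels are void
acc_camvid3(inp, targ), acc_camvid2(inp, targ)    # -> (tensor(0.7500), tensor(0.))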
dls.vocab = codes
learn = unet_learner(dls, resnet34, metrics=acc_camvid2, self_attention=True, act_cls=Mish, opt_func=opt)
learn.load('seg_allone_half_real');
And now let's find our learning rate and train!
learn.lr_find()
lr = 1e-4
learn.fit_flat_cos(12, slice(lr))
learn.save('seg_full_1_real') # save a checkpoint just in case we need to restart from here
learn.unfreeze()
lrs = slice(1e-6,lr/10); lrs
learn.fit_flat_cos(12, lrs)
learn.save('seg_allone_full_real')
learn.show_results(max_n=4, figsize=(18,8))
interp = SegmentationInterpretation.from_learner(learn)
nplot, x = 10, 1
for i in interp.top_losses(nplot).indices:
    print(f"[{x}] {dls.valid_ds.items[i]}")
    x += 1
interp.plot_top_losses(k=nplot)
preds, targs, losses = learn.get_preds(with_loss=True) # validation set only
print(preds.shape, targs.shape)
len(preds)
def save_tmask(tmask, fname, argmax=True):
    tmask_new = tmask.argmax(dim=0).cpu().numpy() if argmax else tmask.cpu().numpy()
    rescaled = (255.0 / tmask_new.max() * (tmask_new - tmask_new.min())).astype(np.uint8)
    im = Image.fromarray(rescaled)
    im.save(fname)
seg_img_dir = 'seg_images'
!rm -rf {seg_img_dir}; mkdir {seg_img_dir}
results = []
for i in range(len(preds)):
    #line_list = [dls.valid.items[i].stem]+[round(targs[i].cpu().numpy().item(),2), round(preds[i][0].cpu().numpy().item(),2), losses[i].cpu().numpy(), i]
    filestem = dls.valid.items[i].stem
    line_list = [filestem]+[losses[i].cpu().numpy(), i]
    save_tmask(preds[i], seg_img_dir+'/'+filestem+'_pred.png')
    #save_tmask(targs[i], seg_img_dir+'/'+filestem+'_targ.png', argmax=False) # already got targs as inputs!
    results.append(line_list)
# store as pandas dataframe
res_df = pd.DataFrame(results, columns=['filename', 'loss','i'])
res_df = res_df.sort_values('loss', ascending=False)
res_df.to_csv('segmentation_allone_top_losses_real.csv', index=False)