This is a straight-up COPY of Zach Mueller's Walk With Fast AI Lesson 4 - Image Segmentation, EXCEPT we'll make it into a segmentation REGRESSION model (with one "class"), instead of classification (with 30 classes). The dataset is CAMVID.
(In a different notebook TBD, I'll switch to my own data.)
from fastai.vision.all import *
Below you will find the exact imports for everything we use today
from fastcore.xtras import Path
from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos
from fastai.data.block import DataBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, FuncSplitter, Normalize
from fastai.layers import Mish # MishJIT gives me trouble :-(
from fastai.losses import BaseLoss, MSELossFlat, CrossEntropyLossFlat, BCEWithLogitsLossFlat
from fastai.optimizer import ranger
from fastai.torch_core import tensor
from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
from fastai.vision.learner import unet_learner
from PIL import Image
import numpy as np
from torch import nn
from torchvision.models.resnet import resnet34
import torch
import torch.nn.functional as F
!nvidia-smi
torch.cuda.is_available()
path = untar_data(URLs.CAMVID)
Our validation set is inside a text document called valid.txt, split by new lines. Let's read it in:
valid_fnames = (path/'valid.txt').read_text().split('\n')
valid_fnames[:5]
Let's look at an image and see how everything lines up
path_im = path/'images'
path_lbl = path/'labels'
First we need our filenames
fnames = get_image_files(path_im)
lbl_names = get_image_files(path_lbl)
And now let's work with one of them
img_fn = fnames[10]
img = PILImage.create(img_fn)
img.show(figsize=(5,5))
Now let's grab our y's. They live in the labels folder and are denoted by a _P in the filename
get_msk = lambda o: path/'labels'/f'{o.stem}_P{o.suffix}'
The stem and suffix grab everything before and after the period respectively.
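For example, on a made-up filename (purely illustrative, not a real CamVid frame) the mapping looks like this:
# Purely illustrative: how get_msk maps an image path to its mask path (the filename is made up)
example_fn = Path('images/some_frame_0001.png')
example_fn.stem, example_fn.suffix   # ('some_frame_0001', '.png')
get_msk(example_fn)                  # -> path/'labels'/'some_frame_0001_P.png'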
Our masks are of type PILMask, and we will set the opacity (alpha) to 1 as we are not overlaying this on anything yet
msk = PILMask.create(get_msk(img_fn))
msk.show(figsize=(5,5), alpha=1)
Now if we look at what our mask actually is, we can see it's a giant array of pixels:
tensor(msk)
Where each one represents a class that we can find in codes.txt. Let's make a vocabulary with it
codes = np.loadtxt(path/'codes.txt', dtype=str); codes
yrange = len(codes); yrange
We need a split function that will split from our list of valid filenames we grabbed earlier. Let's try making our own.
def FileSplitter(fname):
    "Split `items` into train/valid depending on whether their names appear in the file `fname`."
    valid = Path(fname).read_text().split('\n')
    def _func(x): return x.name in valid
    def _inner(o, **kwargs): return FuncSplitter(_func)(o)
    return _inner
This takes in our items (filenames) and checks whether each item's name appears in our list of validation filenames. FuncSplitter then splits between training and validation sets: its predicate returns True for validation-set items and False otherwise.
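To see concretely what the splitter returns, here's a tiny sketch with made-up filenames (the real items come from get_image_files):
# Purely illustrative: FuncSplitter builds (train indices, valid indices) from a predicate.
demo_items = [Path('images/a.png'), Path('images/b.png'), Path('images/c.png')]  # made-up items
demo_valid = {'b.png'}                                                           # made-up validation names
FuncSplitter(lambda o: o.name in demo_valid)(demo_items)  # roughly ([0, 2], [1]): train idxs, valid idxs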
This first round we will train at half the image size
sz = msk.shape; sz
half = tuple(int(x/2) for x in sz); half
camvid = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_image_files,
                   splitter=FileSplitter(path/'valid.txt'),
                   get_y=get_msk,
                   batch_tfms=[*aug_transforms(size=half), Normalize.from_stats(*imagenet_stats)])
dls = camvid.dataloaders(path/'images', bs=4)
Let's look at a batch, and look at all the classes between codes 1 and 30 (ignoring Animal and Wall)
dls.show_batch(max_n=4, vmin=1, vmax=30, figsize=(14,10))
Lastly, let's make our vocabulary a part of our DataLoaders, as our loss function needs to deal with the Void label
dls.vocab = codes
Now we need a methodology for grabbing that particular code from our output of numbers. Let's make everything into a dictionary
name2id = {v:k for k,v in enumerate(codes)}
name2id
Awesome! Let's make an accuracy function
void_code = name2id['Void']
For segmentation, we squeeze the target so it's a plain matrix of class codes for our segmentation mask. From there, we match the argmax of our predictions to the target's mask for each pixel (ignoring Void pixels) and take the average. For the regression version, we instead count a pixel as correct when its predicted value lands within 0.5 of the target code
def acc_camvid(inp, targ):
    targ = targ.squeeze(1)
    mask = targ != void_code
    return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()
def acc_camvid_reg(inp, targ):
    inp  = inp.squeeze(1)   # predictions come in as (bs, 1, H, W); drop the channel dim so shapes match
    targ = targ.squeeze(1)
    # 1 - pixelwise error rate (within +/- 0.5 of each other). note this includes voids
    #return 1 - (inp-targ).abs().round().type(torch.IntTensor).type(torch.FloatTensor).mean() # nope error's too big
    #return 1 - torch.count_nonzero( (inp-targ).abs().round().type(torch.IntTensor) ) / inp.numel()
    return 1 - (inp-targ).abs().round().clamp(max=1).mean()
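A quick sanity check on made-up numbers (a fake batch of two 2x2 "images", one prediction channel each):
# Made-up tensors: a pixel counts as correct when the prediction is within 0.5 of the target code.
toy_inp  = torch.tensor([[[[ 4.3, 17.6],
                           [ 9.9,  2.1]]],
                         [[[ 0.2,  6.4],
                           [30.0, 12.4]]]])   # (2, 1, 2, 2) fake predictions
toy_targ = torch.tensor([[[ 4., 17.],
                          [10.,  5.]],
                         [[ 0.,  6.],
                          [30., 12.]]])       # (2, 2, 2) fake masks
acc_camvid_reg(toy_inp, toy_targ)             # 6 of 8 pixels within 0.5 -> tensor(0.7500)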
U-Net allows us to make pixel-wise predictions by sizing the image down and then blowing it back up into a high-resolution output. The first part we call an "encoder" and the second a "decoder"
In the figure from the U-Net paper, the authors use the arrows to denote the different operations between layers
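A toy sketch of that encoder/decoder idea with a single skip connection (just an illustration I'm adding here, not fastai's actual DynamicUnet):
# Minimal illustration only -- not the real architecture.
class TinyUnetBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.down = nn.Conv2d(c, c*2, 3, stride=2, padding=1)  # encoder: halve the spatial size
        self.up   = nn.ConvTranspose2d(c*2, c, 2, stride=2)    # decoder: double it back up
        self.out  = nn.Conv2d(c*2, 1, 1)                       # 1x1 conv after the skip concat
    def forward(self, x):
        skip = x                                               # keep the high-res features
        x = F.relu(self.down(x))
        x = F.relu(self.up(x))
        return self.out(torch.cat([x, skip], dim=1))           # "copy and crop" style skip connection

TinyUnetBlock(8)(torch.randn(1, 8, 32, 32)).shape  # torch.Size([1, 1, 32, 32])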
We have a special unet_learner. Something new is that we can pass in some model configuration options to customize it:
- Blur/blur final: avoid checkerboard artifacts
- Self attention: a self-attention layer
- y_range: last activations go through a sigmoid for rescaling (see the sketch after this list)
- Last cross: cross-connection with the direct model input
- Bottle: bottleneck or not on that cross
- Activation function
- Norm type
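Here is roughly what the y_range option amounts to (a minimal sketch of the idea; fastai squashes the final activations with a sigmoid and stretches them to the requested range):
# Minimal sketch of y_range rescaling (illustrative helper, not the fastai function itself).
def sketch_sigmoid_range(x, lo, hi): return torch.sigmoid(x) * (hi - lo) + lo

raw = torch.tensor([-5., 0., 5.])          # unbounded activations
sketch_sigmoid_range(raw, 0, len(codes))   # squashed strictly inside (0, len(codes))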
Let's make a unet_learner that uses some of the new state-of-the-art techniques. Specifically:
- Self-attention layers: self_attention = True
- Mish activation function: act_cls = Mish
Along with this we will use the Ranger optimizer.
opt = ranger
learn = unet_learner(dls, resnet34, n_out=1, y_range=(0,len(codes)), loss_func=MSELossFlat(), metrics=acc_camvid_reg, self_attention=True, act_cls=Mish, opt_func=opt)
learn.summary()
If we do a learn.summary() we can see this blow-up trend, and see that our model came in frozen. Let's find a learning rate
lr = learn.lr_find().valley
print("Suggested Learning Rate =",lr)
With our new optimizer, we will also want to use a different fit function, called fit_flat_cos
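fit_flat_cos holds the learning rate flat for most of training and then anneals it along a cosine curve. A rough sketch of the schedule's shape (assuming a 75% flat phase, which is my reading of the defaults; check the fastai docs for the exact values):
# Rough shape of a flat-then-cosine schedule; flat_frac=0.75 is an assumption, not read from fastai.
import math
def sketch_flat_cos(pct, lr, flat_frac=0.75):
    if pct < flat_frac: return lr                    # flat phase at the full learning rate
    t = (pct - flat_frac) / (1 - flat_frac)          # progress through the cosine phase, 0..1
    return lr * (math.cos(t * math.pi) + 1) / 2      # anneal smoothly towards 0

[sketch_flat_cos(p, 1e-3) for p in (0.0, 0.5, 0.8, 1.0)]  # [0.001, 0.001, ~0.0009, 0.0]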
learn.fit_flat_cos(10, slice(lr))
^^ Looks like we're getting some overfitting, i.e. the model isn't generalizing well to the validation set. :-(
Final accuracy score for this part of Zach's classification-segmentation notebook was > 91%.
learn.save('stage-1') # Zach saves in case Colab dies / gives OOM
learn.load('stage-1'); # he reloads as a way of skipping what came before if he restarts the runtime.
learn.show_results(max_n=4, figsize=(12,6))
Let's unfreeze the model and decrease our learning rate by a factor of 4 (rule of thumb)
lrs = slice(lr/400, lr/4)
lr, lrs
learn.unfreeze()
And train for a bit more
learn.fit_flat_cos(12, lrs)
^Yuck. Definitely overfitting.
learn.save('model_1')
And look at a few results
learn.show_results(max_n=4, figsize=(18,8))
So it WORKS, just not nearly as cleanly as the classification version. (Compare similar images in Zach's original notebook)
dl = learn.dls.test_dl(fnames[:5])
dl.show_batch()
Let's do the first five pictures
preds = learn.get_preds(dl=dl)
preds[0].shape
Alright so we have a 5x1x360x480, just like we wanted
What does this mean? We passed in five images, so the first axis indexes each image in our batch. Let's look at the first one
pred_1 = preds[0][0].squeeze()
pred_1.shape
And look at the mask:
msk = PILMask.create(pred_1)
msk.show(figsize=(5,5), alpha=1)
What do we do from here? We need to save it away. We can do this one of two ways: as a numpy array converted to an image, or as a raw tensor (to use later as-is)
pred_arx = pred_1.numpy()  # convert the raw prediction tensor to a numpy array
rescaled = (255.0 / pred_arx.max() * (pred_arx - pred_arx.min())).astype(np.uint8)
im = Image.fromarray(rescaled)
im
im.save('test.png')
Let's do the same for each of our files in a loop
for i, pred in enumerate(preds[0]):
    pred_arr = pred.squeeze().numpy()  # (1, H, W) regression output -> (H, W) array
    rescaled = (255.0 / pred_arr.max() * (pred_arr - pred_arr.min())).astype(np.uint8)
    im = Image.fromarray(rescaled)
    im.save(f'Image_{i}.png')
Now let's save away the raw:
torch.save(preds[0][0], 'Image_1.pt')
pred_1 = torch.load('Image_1.pt')
plt.imshow(pred_1.squeeze())  # drop the single channel dim to plot the regression output
from fastai.vision.all import *
path = untar_data(URLs.CAMVID)
valid_fnames = (path/'valid.txt').read_text().split('\n')
get_msk = lambda o: path/'labels'/f'{o.stem}_P{o.suffix}'
codes = np.loadtxt(path/'codes.txt', dtype=str); codes
def FileSplitter(fname):
    "Split `items` into train/valid depending on whether their names appear in the file `fname`."
    valid = Path(fname).read_text().split('\n')
    def _func(x): return x.name in valid
    def _inner(o, **kwargs): return FuncSplitter(_func)(o)
    return _inner
name2id = {v:k for k,v in enumerate(codes)}
void_code = name2id['Void']
def acc_camvid(inp, targ):
    targ = targ.squeeze(1)
    mask = targ != void_code
    return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()

def acc_camvid_reg(inp, targ):  # regression metric from part 1, re-declared so it survives a restart
    inp, targ = inp.squeeze(1), targ.squeeze(1)
    return 1 - (inp-targ).abs().round().clamp(max=1).mean()
And re-make our dataloaders. But this time we want our size to be the full size (sz, the mask size we grabbed earlier)
camvid = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_image_files,
                   splitter=FileSplitter(path/'valid.txt'),
                   get_y=get_msk,
                   batch_tfms=[*aug_transforms(size=sz), Normalize.from_stats(*imagenet_stats)])
We'll also want to lower our batch size to not run out of memory
dls = camvid.dataloaders(path/"images", bs=1)
Let's assign our vocab, make our learner, and load our weights
opt = ranger
dls.vocab = codes
learn = unet_learner(dls, resnet34, n_out=1, y_range=(0,len(codes)), loss_func=MSELossFlat(), metrics=acc_camvid_reg, self_attention=True, act_cls=Mish, opt_func=opt)  # same regression setup as before, so 'model_1' loads cleanly
learn.load('model_1');
And now let's find our learning rate and train!
learn.lr_find()
lr = 1e-3
learn.fit_flat_cos(10, slice(lr))
learn.save('full_1')
learn.unfreeze()
lrs = slice(1e-6,lr/10); lrs
learn.fit_flat_cos(10, lrs)
learn.save('full_2')
learn.show_results(max_n=4, figsize=(18,8))
We can use weighted loss functions to help with class imbalance. We need to do this because simply oversampling won't quite work here! So, how do we do it? fastai's CrossEntropyLossFlat is just a wrapper around PyTorch's CrossEntropyLoss, so we can pass in a weight parameter (even if it doesn't show up in our autocompletion!)
class CrossEntropyLossFlat(BaseLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    y_int = True
    def __init__(self, *args, axis=-1, **kwargs): super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
    def decodes(self, x):    return x.argmax(dim=self.axis)
    def activation(self, x): return F.softmax(x, dim=self.axis)
But what should this weight be? It needs to be a 1xn tensor, where n is the number of classes in your dataset. We'll use a quick example where every class but the last has a weight of 90% and the last class has a weight of 110%. Also, as we are training on the GPU, we need the tensor to be on the GPU as well:
weights = torch.tensor([[0.9]*31 + [1.1]]).cuda()
weights
Now we can pass this into CrossEntropyLossFlat. Note: as this is segmentation, we need to set the axis to 1
learn.loss_func = CrossEntropyLossFlat(weight=weights, axis=1)
(or pass it in when creating the learner)
loss_func = CrossEntropyLossFlat(weight=weights, axis=1)
learn = unet_learner(dls, resnet34, metrics=acc_camvid, loss_func=loss_func)
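As a quick shape check that the flattening handles segmentation-shaped tensors, here's a toy example on CPU with made-up shapes (and a plain 1-D weight with one value per class):
# Made-up example: 4 classes, batch of 1, a 2x3 "image". Shows the axis=1 flattening at work.
toy_w      = torch.tensor([1., 1., 1., 2.])   # heavier penalty on the last class
toy_loss   = CrossEntropyLossFlat(weight=toy_w, axis=1)
toy_logits = torch.randn(1, 4, 2, 3)          # (bs, n_classes, H, W)
toy_mask   = torch.randint(0, 4, (1, 2, 3))   # (bs, H, W) integer class codes
toy_loss(toy_logits, toy_mask)                # a scalar weighted cross-entropy loss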