import torch, re
tv, cv = torch.__version__, torch.version.cuda
tv = re.sub(r'\+cu.*', '', tv)            # strip any '+cuXXX' build suffix
TORCH_VERSION = 'torch'+tv[0:-1]+'0'      # e.g. '1.9.1' -> 'torch1.9.0'
CUDA_VERSION = 'cu'+cv.replace('.','')    # e.g. '11.1' -> 'cu111'
print(f"TORCH_VERSION={TORCH_VERSION}; CUDA_VERSION={CUDA_VERSION}")
print(f"CUDA available = {torch.cuda.is_available()}, Device count = {torch.cuda.device_count()}, Current device = {torch.cuda.current_device()}")
print(f"Device name = {torch.cuda.get_device_name()}")
print("hostname:")
!hostname
from fastai.vision.all import *
from espiownage.core import *
Below you will find the exact imports for everything we use today
from fastcore.xtras import Path
from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos
from fastai.data.block import DataBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, FuncSplitter, Normalize
from fastai.layers import Mish
from fastai.losses import BaseLoss
from fastai.optimizer import ranger
from fastai.torch_core import tensor
from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
from fastai.vision.learner import unet_learner
from PIL import Image
import numpy as np
from torch import nn
from torchvision.models.resnet import resnet34
import torch
import torch.nn.functional as F
#path = untar_data('https://anonymized.machine.com/~drscotthawley/espiownage-cyclegan.tgz')
path = Path('/home/drscotthawley/datasets/espiownage-fake/')
#path = Path('/home/drscotthawley/datasets/espiownage-cleaner/') # not cleaned but clean-er than before!
Let's look at an image and see how everything lines up
path_im = path/'images'
path_lbl = path/'masks'
First we need our filenames
import glob
#fnames = get_image_files(path_im)
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
lbl_names = get_image_files(path_lbl)
len(meta_names), len(fnames), len(lbl_names)
And now let's work with one of them
img_fn = fnames[10]
print(img_fn)
img = PILImage.create(img_fn)
img.show(figsize=(5,5))
Now let's grab our y's. They live in the masks folder and are denoted by a _P suffix
get_msk = lambda o: path/'masks'/f'{o.stem}_P{o.suffix}'
The stem and suffix grab everything before and after the period, respectively. Our masks are of type PILMask, and we will set the transparency (alpha) to 1 since we are not overlaying the mask on anything yet
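As a quick illustration of what get_msk builds (the filename here is hypothetical):
o = Path('images/steelpan_0001.png')  # hypothetical example filename
o.stem, o.suffix                      # -> ('steelpan_0001', '.png')
get_msk(o)                            # -> path/'masks'/'steelpan_0001_P.png'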
msk_name = get_msk(img_fn)
print(msk_name)
msk = PILMask.create(msk_name)
msk.show(figsize=(5,5), alpha=1)
Now if we look at what our mask actually is, we can see it's a giant array of pixels:
print(tensor(msk))
And let's just make sure it's a clean mask with no anti-aliasing artifacts. Let's see what values it contains:
set(np.array(msk).flatten())
Each value represents a class. Rather than reading class names from a codes.txt file, we'll build our vocabulary from these values directly:
#colors = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110]
#colors = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
colors = list(set(np.array(msk).flatten()))
codes = [str(n) for n in range(len(colors))]; codes
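One caveat worth flagging (our note, not part of the original walkthrough): MaskBlock(codes) effectively assumes the mask pixels are already the class indices 0 through len(codes)-1, since the loss uses them as targets directly. If a dataset used spaced values like the commented-out [0, 10, ..., 110] above, a hypothetical remapping step might look like:
color2idx = {c: i for i, c in enumerate(sorted(colors))}  # e.g. {0: 0, 10: 1, ...}
def remap_mask(fn):
    m = np.array(PILMask.create(fn))
    out = np.zeros_like(m)
    for c, i in color2idx.items(): out[m == c] = i
    return PILMask.create(out)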
We need a way to split our data into training and validation sets. The original version of this walkthrough built a custom split function that checks each filename against a list of validation filenames; here we'll simply use a RandomSplitter (a filename-based alternative is sketched below).
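If you did want a filename-based split, here is a minimal sketch using the FuncSplitter imported earlier (the valid.txt file is hypothetical):
valid_fnames = (path/'valid.txt').read_text().split('\n')  # hypothetical list of validation image names
def valid_func(o): return o.name in valid_fnames
splitter = FuncSplitter(valid_func)  # drop-in replacement for RandomSplitter() in the DataBlock below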
For this first round, we will train at half the image size
sz = msk.shape; sz
half = tuple(int(x/2) for x in sz); half
cyclegan = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
get_items=get_image_files,
splitter=RandomSplitter(),
get_y=get_msk,
batch_tfms=[*aug_transforms(size=half), Normalize.from_stats(*imagenet_stats)])
dls = cyclegan.dataloaders(path/'images', fnames=fnames, bs=4)
Let's look at a batch and check all the classes
dls.show_batch(max_n=4, vmin=0, vmax=2, figsize=(14,10))
Lastly, let's make our vocabulary a part of our DataLoaders, as our loss function needs to deal with the Void label
dls.vocab = codes
Now we need a way to map each code back to its class index. Let's put everything into a dictionary
name2id = {v: int(v) for v in codes}
name2id
Awesome! Let's make an accuracy function
void_code = name2id['0'] # name2id['Void']
For segmentation accuracy, we squeeze the target down to a matrix of class indices, mask out the Void pixels, then compare the argmax of the predictions against the target at each remaining pixel and take the mean
def acc_camvid(inp, targ):
targ = targ.squeeze(1)
mask = targ != void_code
if len(targ[mask]) == 0: mask = (targ == void_code) # Empty image (all void)
return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()
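A quick sanity check on random tensors makes the expected shapes concrete (the shapes here are illustrative):
xb = torch.randn(2, len(codes), 8, 8)            # (batch, n_classes, h, w) logits
yb = torch.randint(0, len(codes), (2, 1, 8, 8))  # (batch, 1, h, w) integer mask
acc_camvid(xb, yb)                               # a scalar tensor in [0, 1]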
Let's make a unet_learner that uses some of the newer state-of-the-art techniques. Specifically:
- Self-attention layers: self_attention = True
- Mish activation function: act_cls = Mish

Along with this we will use Ranger as the optimizer function.
opt = ranger
learn = unet_learner(dls, resnet34, metrics=acc_camvid, self_attention=True, act_cls=Mish, opt_func=opt)
learn.lr_find()
lr = 1e-4
With our new optimizer, we will also want to use a different fit function, called fit_flat_cos, which keeps the learning rate flat for most of training and then anneals it with a cosine curve
learn.fit_flat_cos(12, slice(lr))
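To see what that schedule looks like, here is a sketch using fastai's combined_cos helper (as far as we can tell, this is what fit_flat_cos builds internally, with pct_start defaulting to 0.75 and a final lr of lr/1e5):
from fastai.callback.schedule import combined_cos
import matplotlib.pyplot as plt
sched = combined_cos(0.75, lr, lr, lr/1e5)    # flat at lr for 75% of training, then cosine decay
xs = torch.linspace(0, 1, 100)
plt.plot(xs, [sched(float(x)) for x in xs])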
learn.save('stage-1-real-fake2') # Zach saves in case Colab dies / gives OOM
learn.load('stage-1-real-fake2'); # he reloads as a way of skipping what came before if he restarts the runtime.
learn.show_results(max_n=4, figsize=(12,6))
Let's unfreeze the model and decrease our learning rate by 4 (a rule of thumb); the slice below spreads rates from lr/400 for the earliest layers up to lr/4 for the head
lrs = slice(lr/400, lr/4)
lr, lrs
learn.unfreeze()
And train for a bit more
learn.fit_flat_cos(12, lrs)
Now let's save that model away
learn.save('model_1_fake2')
learn.load('model_1_fake2')
And look at a few results
learn.show_results(max_n=6, figsize=(10,10))
dl = learn.dls.test_dl(fnames[0:6])
dl.show_batch()
Let's get predictions for those six pictures
preds = learn.get_preds(dl=dl)
len(preds)
preds[0].shape
Alright, so we have a tensor shaped (images, classes, height, width); let's check the class count:
len(codes)
What does this mean? We passed in six images, so the first axis indexes each image in our batch. Let's look at one of them
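To make the axes explicit (these names are ours, not fastai's):
n_images, n_classes, h, w = preds[0].shape
n_images, n_classes, h, w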
ind = 5
pred_1 = preds[0][ind]
pred_1.shape
Now let's take the argmax of our values
pred_arx = pred_1.argmax(dim=0)
And look at it
plt.imshow(pred_arx)
What do we do from here? We need to save it away. We can do this one of two ways: as a numpy array rescaled into an image, or as a raw tensor (to use later as-is)
pred_arx = pred_arx.numpy()
rescaled = (255.0 / pred_arx.max() * (pred_arx - pred_arx.min())).astype(np.uint8)  # stretch class indices across the 0-255 grayscale range
im = Image.fromarray(rescaled)
im
im.save('test_fake2.png')
Let's loop over our predictions and save each one the same way
for i, pred in enumerate(preds[0]):
pred_arg = pred.argmax(dim=0).numpy()
rescaled = (255.0 / pred_arg.max() * (pred_arg - pred_arg.min())).astype(np.uint8)
im = Image.fromarray(rescaled)
im.save(f'Image_{i}_fake2.png')
Now let's save away the raw:
torch.save(preds[0][ind], 'Image_1_fake2.pt')
pred_1 = torch.load('Image_1_fake2.pt')
plt.imshow(pred_1.argmax(dim=0))
(Restart point: we redo the imports and setup here, so the full-size training below can run in a fresh session.)
from fastai.vision.all import *
from espiownage.core import *
import glob
path = Path('/home/drscotthawley/datasets/espiownage-fake/')
path_im = path/'images'
path_lbl = path/'masks'
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
lbl_names = get_image_files(path_lbl)
get_msk = lambda o: path/'masks'/f'{o.stem}_P{o.suffix}'
colors = [0, 1]
codes = [str(n) for n in colors]; codes
sz = (384, 512)
name2id = {v: int(v) for v in codes}
void_code = name2id['0']
And re-make our dataloaders. But this time we want our size to be the full size
seg_db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
get_items=get_image_files,
splitter=RandomSplitter(),
get_y=get_msk,
batch_tfms=[*aug_transforms(size=sz), Normalize.from_stats(*imagenet_stats)])
We'll also want to lower our batch size so we don't run out of memory
dls = seg_db.dataloaders(path/"images", fnames=fnames, bs=2)
Let's assign our vocab, make our learner, and load our weights
opt = ranger
def acc_camvid2(inp, targ):
targ = targ.squeeze(1)
mask = targ != void_code
if len(targ[mask]) == 0: mask = (targ == void_code) # Empty image
return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()
dls.vocab = codes
learn = unet_learner(dls, resnet34, metrics=acc_camvid2, self_attention=True, act_cls=Mish, opt_func=opt)
learn.load('model_1_real');
And now let's find our learning rate and train!
learn.lr_find()
lr = 7e-5
learn.fit_flat_cos(10, slice(lr))
learn.save('seg_full_1_fake2')
learn.unfreeze()
lrs = slice(1e-6,lr/10); lrs
learn.fit_flat_cos(10, lrs)
learn.save('seg_full_2_fake2')
learn.show_results(max_n=4, figsize=(18,8))
interp = SegmentationInterpretation.from_learner(learn)
nplot = 10
for x, i in enumerate(interp.top_losses(nplot).indices, start=1):
    print(f"[{x}] {dls.valid_ds.items[i]}")
interp.plot_top_losses(k=nplot)
preds, targs, losses = learn.get_preds(with_loss=True) # validation set only
print(preds.shape, targs.shape)
len(preds)
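Since get_preds(with_loss=True) hands back losses for the validation set, we can rank the worst images ourselves, mirroring interp.top_losses. A sketch, assuming losses is (or can be reduced to) one scalar per item:
if losses.ndim > 1 or len(losses) != len(preds):   # per-pixel losses: reduce to one value per image
    losses = losses.view(len(preds), -1).mean(dim=1)
worst = losses.argsort(descending=True)[:5]        # indices of the five highest-loss items
for i in worst: print(f"{losses[i]:.4f}  {dls.valid_ds.items[i]}")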