Acknowledgement: I took Zach Mueller's Image Segmentation tutorial notebook (based on the main FastAI lesson notebook) and modified it to do regression (as per Zach's suggestions) and to work with my own data.
Note: The WandB links will 404, because there is no "drscotthawley" WandB account. We just used `sed` to replace the real username in the .ipynb files.
# --- Install dependencies (notebook shell magic; not valid plain Python) ---
!pip install -Uqq fastai espiownage==0.0.45 mrspuff typing_extensions -q --upgrade
import espiownage
from espiownage.core import *  # provides get_data, sysinfo, meta_to_img_path, meta_to_mask_path, ...
sysinfo()  # espiownage helper: prints system/environment info
print(f"espiownage version {espiownage.__version__}")
# --- fastai imports; the star import also brings in names used later (os, pd, mae, IndexSplitter, flatten_check) ---
# NOTE(review): everything below fastai.vision.all is redundant with the star import — confirm
from fastai.vision.all import *
from fastcore.xtras import Path
from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos
from fastai.data.block import DataBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, FuncSplitter, Normalize
from fastai.layers import Mish # MishJIT gives me trouble :-(
from fastai.losses import BaseLoss, MSELossFlat, CrossEntropyLossFlat, BCEWithLogitsLossFlat
from fastai.optimizer import ranger
from fastai.torch_core import tensor
from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
from fastai.vision.learner import unet_learner
from PIL import Image
import numpy as np
import random
from torch import nn
from torchvision.models.resnet import resnet34
import torch
import torch.nn.functional as F
import glob
from pathlib import Path  # NOTE(review): re-shadows fastcore's patched Path imported above — confirm intended
# --- Run configuration ---
dataset_name = 'cleaner' # choose from: cleaner, preclean, spnet, cyclegan, fake
project = 'segreg_kfold'  # WandB project name
path = get_data(dataset_name)  # espiownage helper: fetches/locates the dataset, returns its root Path
bin_size = 0.7  # ring-count bin width used when the masks were generated
maskdir = path / ('masks_'+str(bin_size))
# We can also generate masks dynamically using `espiownage`'s `gen_masks` script:
#!gen_masks --quiet --step={bin_size} --maskdir={maskdir} --files={str(path/'annotations')+'/*.csv'}
path_im = path/'images'
path_mask = path/maskdir  # NOTE(review): maskdir is already an absolute path, so this equals maskdir (pathlib drops the left side)
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))  # one annotation CSV per image
img_names = [meta_to_img_path(x, img_bank=path_im) for x in meta_names] # img_names
mask_names = sorted(get_image_files(path_mask))
print("lengths of input lists:",len(meta_names), len(img_names), len(mask_names))
def shuffle_together(*ls):
    "shuffle any number of lists in the same way"
    # Zip the lists into rows so one shuffle permutes every list identically.
    # random.shuffle (not sample/permutation) is kept so seeded runs stay reproducible.
    rows = list(zip(*ls))
    random.shuffle(rows)
    # Unzip back: the caller receives one tuple per original list.
    return zip(*rows)
random.seed(0) # so you can start again/elsewhere & keep going from the same 'shuffle'
img_names, meta_names, mask_names = shuffle_together(img_names, meta_names, mask_names)
# sanity checks: after shuffling, all three lists must still line up index-for-index
assert len(img_names)==len(meta_names)
assert len(img_names)==len(mask_names)
for i in range(len(img_names)):
    # each annotation CSV must map (via the espiownage path helpers) to the mask & image at the same index
    assert os.path.basename(meta_to_mask_path(meta_names[i],mask_dir=str(path_mask)+'/')) == os.path.basename(mask_names[i]), "mask and meta don't agree"
    assert os.path.basename(meta_to_img_path(meta_names[i])) == os.path.basename(img_names[i]), f'{os.path.basename(meta_to_img_path(meta_names[i]))} != {os.path.basename(img_names[i])}'
print("\nThe following should match up with each other and also be SAME THING each time you restart this notebook:")
for x in [meta_names, img_names, mask_names]:
    print(os.path.basename(x[0]))
# ^expected output:
#   06240907_proc_01617.csv
#   06240907_proc_01617.png
#   06240907_proc_01617_P.png
# Map an image Path to its mask Path (masks are named <stem>_P<ext> in maskdir)
get_msk = lambda o: path/maskdir/f'{o.stem}_P{o.suffix}'
# One mask "color" per ring-count bin; 11 presumably is the max ring count — TODO confirm
colors = list(range(int(11/bin_size) + 1))
print("colors = ",colors)
codes = [str(n) for n in range(len(colors))];  # fastai MaskBlock wants string class codes
print("codes = ",codes)
yrange = len(codes);  # NOTE(review): appears unused below — y_range is recomputed later with headroom
print("yrange = ",yrange)
sz = (384, 512)  # full training size (height, width)
half = tuple(int(x/2) for x in sz);  # half resolution for the first progressive-resizing stage
print("half = ",half)
def sr_acc_old(inp, targ): # scores both voids and objects
    "Fraction of pixels whose prediction rounds to the target (scores voids and objects alike)."
    # Drop the mask's channel axis, then count pixels whose |error| rounds to 0.
    err = (inp - targ.squeeze(1)).abs()
    misses = err.round().clamp(max=1)  # 1 per wrong pixel, 0 per right pixel
    return 1 - misses.mean()
def sr_acc(inp, targ, bin_size=1):
    "segmentation regression accuracy: Are we within +/- bin_size? tries to score only objects, not voids"
    # Flatten both tensors after dropping the mask channel axis.
    inp, targ = flatten_check(inp, targ.squeeze(1))  # https://docs.fast.ai/metrics.html#flatten_check
    nonvoid = targ != void_code  # void_code is set at module level after the dataloaders are built
    if nonvoid.any():
        hits = (inp[nonvoid] - targ[nonvoid]).abs() < bin_size  # don't count voids in metric
    else:
        # Empty image (all void): fall back to scoring every pixel — gonna be ~100%!
        hits = (inp - targ).abs() < bin_size
    return hits.float().mean()
# Cell
# Fixed-threshold wrappers around sr_acc. These stay real `def`s (not partials)
# because fastai labels each metric column by the function's __name__.
def sr_acc05(inp, targ):
    "sr_acc with bin_size=0.5"
    return sr_acc(inp, targ, bin_size=0.5)

def sr_acc07(inp, targ):
    "sr_acc with bin_size=0.7"
    return sr_acc(inp, targ, bin_size=0.7)

def sr_acc1(inp, targ):
    "sr_acc with bin_size=1"
    return sr_acc(inp, targ, bin_size=1)

def sr_acc15(inp, targ):
    "sr_acc with bin_size=1.5"
    return sr_acc(inp, targ, bin_size=1.5)

def sr_acc2(inp, targ):
    "sr_acc with bin_size=2"
    return sr_acc(inp, targ, bin_size=2)
# --- Weights & Biases experiment logging ---
!pip install wandb -qqq
import wandb
from fastai.callback.wandb import *
wandb.login()
# --- DIY k-fold split: take the k-th contiguous slice of the (seeded-shuffle) file list ---
k = 3 # choose 0 to 4
nk = 5  # total number of folds
nv = int(len(img_names)/nk) # size of val set
bgn = k*nv # ind to start val set
inds = list(range(bgn, bgn+nv)) # indices for this val set
# Half-size DataBlock: images in, integer ring-count masks out
db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
               get_items=get_image_files,
               splitter=IndexSplitter(inds),  # the k-th fold becomes the validation set
               get_y=get_msk,
               batch_tfms=[*aug_transforms(size=half, flip_vert=True), Normalize.from_stats(*imagenet_stats)])
dls = db.dataloaders(path/'images', fnames=img_names, bs=4)
dls.vocab = codes
name2id = {v:k for k,v in enumerate(codes)}  # code string -> integer id
void_code = name2id['0']  # = 0 here (codes are "0","1",...); read by sr_acc above as "background"
opt = ranger  # fastai's ranger optimizer
hrfac = 1.2 # 'headroom factor'
y_range=(0,int(len(codes)*hrfac)) # balance between "clamping" to range of real data vs too much "compression" from sigmoid nonlinearity
#learn = unet_learner(dls, resnet34, yrange=len(codes), loss_func=MSELossFlat(), metrics=acc_camvid, self_attention=True, act_cls=Mish, opt_func=opt)
metrics = [mae, sr_acc_old, sr_acc05, sr_acc07, sr_acc1, sr_acc15, sr_acc2]
# run parameters
epochs, lr = 12*4, 1e-3  # NOTE(review): `epochs` is never referenced below (fits hard-code 12 and 10) — confirm
wandb.init(project=project, name=f'k={k} {dataset_name}') # <-- let wandb make up names #name=f"k={k},e{epochs},lr{lr}")
# n_out=1: a single regression channel instead of per-class logits; MSE loss on the mask values
learn = unet_learner(dls, resnet34, n_out=1, y_range=y_range, loss_func=MSELossFlat(),
                     metrics=metrics, self_attention=True, act_cls=Mish, opt_func=opt,
                     cbs=WandbCallback())
#lr = learn.lr_find().valley
#print("Suggested Learning Rate =",lr)
print("----- HALF SIZE TRAINING")
print("Training: frozen epochs...")
learn.fit_flat_cos(12, slice(lr)) # these frozen epochs don't yield much improvement btw
print("unfreezing model, lowering lr by 4")
learn.unfreeze()
lrs = slice(lr/400, lr/4)  # discriminative lrs: earliest layers at lr/400, head at lr/4
print("Training: unfrozen epochs...")
learn.fit_flat_cos(12, lrs)
halfweights = 'seg_reg_real_half'
print(f"Saving model: {halfweights}")
learn.save(halfweights)
# Nope we're not finished! Save wandb.finish() until after Full size training.
print("\n----- FULL SIZE TRAINING -----")
# Rebuild the DataBlock at full resolution (same fold split, same mask targets)
db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
               get_items=get_image_files,
               splitter=IndexSplitter(inds),
               get_y=get_msk,
               batch_tfms=[*aug_transforms(size=sz, flip_vert=True), Normalize.from_stats(*imagenet_stats)])
dls = db.dataloaders(path/'images', fnames=img_names, bs=2) # smaller batch size because we're now full size
dls.vocab = codes
# Fresh learner on the full-size dataloaders; weights are restored from the half-size run below
learn = unet_learner(dls, resnet34, n_out=1, y_range=y_range, loss_func=MSELossFlat(),
                     metrics=metrics, self_attention=True, act_cls=Mish, opt_func=opt,
                     cbs=WandbCallback())
learn.load(halfweights)  # progressive resizing: warm-start from the half-size weights
#learn.lr_find(end_lr=5e-3)
lr = 3e-4  # presumably hand-picked from earlier lr_find runs — TODO confirm
print("Training: frozen epochs...")
learn.fit_flat_cos(10, slice(lr))
print("unfreezing model, lowering lr by...stuff")
learn.unfreeze()
lrs = slice(1e-6,lr/10); lrs  # discriminative lr range for the unfrozen body
print("Training: unfrozen epochs...")
learn.fit_flat_cos(10, lrs)
print("Finishing WandB")
wandb.finish()
fullweights = 'seg_reg_real_full'
print(f"Saving model: {fullweights}")
learn.save(fullweights)
learn.load(fullweights)
# Predictions and per-item losses on the validation set (the k-th fold)
preds, targs, losses = learn.get_preds(with_loss=True) # validation set only
print(preds.shape, targs.shape)
len(preds)  # (notebook cell echo)
def save_tmask(tmask, fname='', norm=False): # save tensor mask
    """Convert a prediction tensor to an 8-bit grayscale PIL image; optionally save it.

    tmask: tensor whose first element is the (H,W) mask — assumes shape (1,H,W); TODO confirm
    fname: if non-empty, the path to write the image to
    norm:  if True, scale this image's own min..max range to 0..255;
           otherwise use the fixed 0..max(colors) ring-count scale
    Returns the PIL Image.
    """
    arr = tmask[0].squeeze().cpu().numpy()
    if norm:  # auto scale for just this image
        use_min, use_max = arr.min(), arr.max()
    else:     # fixed scale of max ring count (module-level `colors`)
        use_min, use_max = 0, np.max(np.array(colors))
    span = use_max - use_min
    if span == 0:
        span = 1  # constant image: avoid division by zero
    # bug fix: divide by the full (max - min) span so norm=True actually maps max -> 255;
    # the old code divided by use_max, under-scaling whenever use_min > 0.
    # (non-norm path is unchanged: use_min is 0 there, so span == use_max)
    rescaled = (255.0 / span * (arr - use_min)).astype(np.uint8)
    im = Image.fromarray(rescaled)
    if fname != '': im.save(fname)
    return im
seg_img_dir = 'seg_reg_images'  # output directory for predicted-mask PNGs
#!rm -rf {seg_img_dir}; # leave 'em
! mkdir {seg_img_dir}
results = []
for i in range(len(preds)):
    #line_list = [dls.valid.items[i].stem]+[round(targs[i].cpu().numpy().item(),2), round(preds[i][0].cpu().numpy().item(),2), losses[i].cpu().numpy(), i]
    filestem = dls.valid.items[i].stem
    line_list = [filestem]+[losses[i].cpu().numpy(), i]  # filename stem, loss, validation index
    save_tmask(preds[i], seg_img_dir+'/'+filestem+'_pred.png')  # write the predicted mask image
    results.append(line_list)
# store as pandas dataframe, worst (highest-loss) items first
res_df = pd.DataFrame(results, columns=['filename', 'loss','i'])
res_df = res_df.sort_values('loss', ascending=False) # top loss order
res_df.to_csv(f'segreg_top_losses_real_k{k}.csv', index=False)