!pip install -Uqq espiownage fastai
from fastai.vision.all import *
from espiownage.core import *
Let's just be clear from the start about which model we want:
checkpoint = get_checkpoint('segreg')
# TODO: move local directory to someplace user-downloadable
imgdir = '/home/shawley/datasets/other_espi_instruments/resized'
Below you will find the exact imports for everything we use today:
from fastcore.xtras import Path
from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos
from fastai.data.block import DataBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, FuncSplitter, RandomSplitter, Normalize
from fastai.layers import Mish # MishJIT gives me trouble :-(
from fastai.losses import BaseLoss, MSELossFlat, CrossEntropyLossFlat, BCEWithLogitsLossFlat
from fastai.optimizer import ranger
from fastai.torch_core import tensor
from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
from fastai.vision.learner import unet_learner
from PIL import Image
import numpy as np
import random
from torch import nn
from torchvision.models.resnet import resnet34
import torch
import torch.nn.functional as F
import glob
from pathlib import Path
path = get_data('cleaner') # real data is local and private
bin_size = 0.7
maskdir = path / ('masks_'+str(bin_size))
path_im = path/'images'
path_lbl = path/maskdir
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
random.shuffle(fnames)
lbl_names = get_image_files(path_lbl)
get_msk = lambda o: path/maskdir/f'{o.stem}_P{o.suffix}'
colors = list(range(int(11/bin_size) + 1))
codes = [str(n) for n in range(len(colors))]; codes
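As a quick sanity check on the convention assumed here (our reading of the mask-generation step, which happens outside this notebook: each mask pixel holds a ring count quantized in steps of bin_size, so multiplying a mask value by bin_size recovers an approximate ring count), here's a minimal sketch:
ring_count = 3.5                              # hypothetical ring count for one antinode
mask_value = round(ring_count / bin_size)     # -> 5 when bin_size = 0.7
approx_count = mask_value * bin_size          # -> ~3.5, matching the `segreg * bin_size` rescaling used for plotting below
print(len(colors), mask_value, approx_count)  # 16 classes cover ring counts up to ~11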
sz = (384,512)
db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
get_items=get_image_files,
splitter=RandomSplitter(),
get_y=get_msk,
batch_tfms=[*aug_transforms(size=sz, flip_vert=True), Normalize.from_stats(*imagenet_stats)])
dls = db.dataloaders(path/'images', fnames=fnames, bs=16) # if OOM occurs lower bs to as low as 2
dls.vocab = codes
opt = ranger
hrfac = 1.2 # 'headroom factor'
codes = [str(n) for n in range(16)]; codes
y_range=(0,int(len(codes)*hrfac)) # balance between "clamping" to range of real data vs too much "compression" from the sigmoid nonlinearity
learn = unet_learner(dls, resnet34, n_out=1, y_range=y_range, loss_func=MSELossFlat(), self_attention=True, act_cls=Mish, opt_func=opt)
learn.load(str(checkpoint).replace('.pth',''))
from PIL import ImageOps
def save_tmask(tmask, fname='', norm=False, purple=False, blend_img=None): # save tensor mask
tmask_new = tmask[0].squeeze().cpu().numpy()
use_min, use_max = 0, np.max(np.array(colors)) # use scale of max ring count
if norm: use_min, use_max = tmask_new.min(), tmask_new.max() # auto scale for just this image
rescaled = (255.0 / use_max * (tmask_new - use_min)).astype(np.uint8)
im = Image.fromarray(rescaled)
if purple: im = ImageOps.colorize(im, black ="black", white =(255,0,255))
if blend_img is not None: im = Image.blend(blend_img, im, 0.4)
if fname != '': im.save(fname)
return im
Actually, there's a particular order I want to use, because I know how this turns out. So I'm going to specify:
good_fnames = ['/home/shawley/datasets/other_espi_instruments/resized/'+x
for x in ['musical_espi_moore_resized.jpg','musical_espi_wikipedia_resized.jpeg','lute1_resized.png',
'musical_espi_zooniverse_steelpan_resized.jpeg','musical_espi_guitar1.png']]
good_images = []
for f in good_fnames:
good_images.append(PILImage.create(f))
def my_get_preds(fnames, learn=None):
    "Run inference on a list of images (file paths or PILImages) and return the predicted masks."
    dlpred = dls.test_dl(fnames)           # a test DataLoader with the same transforms as training
    preds, _ = learn.get_preds(dl=dlpred)  # targets are ignored for a test set
    print(preds.shape)
    preds = preds.squeeze(1)               # drop the channel dim if it's 1: (N,1,H,W) -> (N,H,W)
    print(preds.shape)  # after squeezing
    return preds
preds = my_get_preds(good_images, learn=learn)
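While we're here, a quick usage example for the save_tmask helper defined above ('pred_mask_example.png' is just a made-up filename; nothing later depends on this):
save_tmask(preds[0:1], fname='pred_mask_example.png', purple=True)  # purple-colorized copy of the first predicted mask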
import matplotlib.pyplot as plt   # plt also comes along with fastai's star import, but we make it explicit here
import matplotlib.gridspec as gridspec
def show_inp_segreg(inp, segreg, vmax=3, bin_size=0.7):
gs = gridspec.GridSpec(1,3,width_ratios=[5,5,0.5])
fig = plt.figure(figsize=(12,4))
#fig, axarr = plt.subplots(1,2, figsize=(12,4))
ax1 = fig.add_subplot(gs[0])
ax1.grid(False)
ax1.imshow(inp)
ax2 = fig.add_subplot(gs[1])
im = ax2.imshow(segreg.cpu().numpy()*bin_size, vmin=0, vmax=vmax)
cbar = fig.colorbar(im, cax=fig.add_subplot(gs[2]))
cbar.ax.tick_params(labelsize=25)
ax2.grid(False)
ax1.axis('off')
ax2.axis('off')
plt.show()
What you're about to see is that for close-up, grainy images, the seg-reg method does a fairly reasonable job:
for i in range(preds.shape[0]):
show_inp_segreg(good_images[i], preds[i])
But for images taken from farther away, or composite images with white-space borders (pulled straight from JASA Twitter!), or even just really 'clean'-looking laser holography images, it barely detects anything at all. ...For the composites and the "small antinodes" in those images, that's because the model wasn't trained on anything like them. For the holography image of the guitar below -- we don't know yet!
Here are some "harder images" that the seg-reg model doesn't do well on:
harder_files, harder_imgs = [], []
for f in sorted(glob.glob(imgdir+'/*')):
if f not in good_fnames:
harder_files.append(f)
harder_imgs.append(PILImage.create(f))
preds = my_get_preds(harder_files, learn=learn)
for i in range(preds.shape[0]):
show_inp_segreg(harder_imgs[i], preds[i])
checkpoint_file = 'seg_allone_full_real'
path = get_data('cleaner')
path_im = path/'images'
path_lbl = path/'masks'
#path = untar_data('https://anonymized.machine.com/~drscotthawley/espiownage-cleaner.tgz')
meta_names = sorted(glob.glob(str(path/'annotations')+'/*.csv'))
fnames = [meta_to_img_path(x, img_bank=path_im) for x in meta_names]
lbl_names = get_image_files(path_lbl)
get_msk = lambda o: path/'masks'/f'{o.stem}_P{o.suffix}'
colors = [0, 1]
codes = [str(n) for n in colors]; codes
sz = (384, 512)
name2id = {c: int(c) for c in codes}  # map each code string to its integer class id
void_code = name2id['0']              # background class index
seg_db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
get_items=get_image_files,
splitter=RandomSplitter(),
get_y=get_msk,
batch_tfms=[*aug_transforms(size=sz, flip_vert=True), Normalize.from_stats(*imagenet_stats)])
dls = seg_db.dataloaders(path/"images", fnames=fnames, bs=2)
opt=ranger
learn = unet_learner(dls, resnet34, self_attention=True, act_cls=Mish, opt_func=opt)
learn.load(checkpoint_file)
Yay! We're in business. First let's try the good images:
preds = my_get_preds(good_images, learn=learn)
preds.shape
preds = preds.argmax(axis=1)
preds.shape
def show_inp_alloneseg(inp, seg, vmax=3, bin_size=0.7, colorbar=False):
fig, axarr = plt.subplots(1,2, figsize=(12,4))
axarr[0].imshow(inp)
im = axarr[1].imshow(seg.cpu().numpy())
if colorbar: plt.colorbar(im, ax=axarr[1])
axarr[0].axis('off')
axarr[1].axis('off')
plt.show()
for i in range(preds.shape[0]):
show_inp_alloneseg(good_images[i], preds[i])
Now for the harder images:
Maybe we could go back and crop those images to zoom in on the antinodes better (see the sketch after these plots). We're not sure right now; we'll keep trying.
preds = my_get_preds(harder_imgs, learn=learn)
preds = preds.argmax(axis=1)
for i in range(preds.shape[0]):
show_inp_alloneseg(harder_imgs[i], preds[i])
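As suggested above, one thing to try on these harder images is zooming in before inference. Here is a minimal, untested sketch that center-crops each image to half its width and height (an arbitrary choice, not something tuned for this data) and re-runs the same prediction loop:
def center_crop(img, frac=0.5):  # frac is an arbitrary zoom factor
    "Crop out the central frac-by-frac region of a PIL(Image)."
    w, h = img.size
    cw, ch = int(w*frac), int(h*frac)
    left, top = (w - cw)//2, (h - ch)//2
    return PILImage.create(np.array(img.crop((left, top, left+cw, top+ch))))

cropped_imgs = [center_crop(im) for im in harder_imgs]
cropped_preds = my_get_preds(cropped_imgs, learn=learn).argmax(axis=1)
for i in range(cropped_preds.shape[0]):
    show_inp_alloneseg(cropped_imgs[i], cropped_preds[i])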
Concluding Remarks
It's not inconceivable that, with a little Transfer Learning from our model(s) and a small training set of annotations for a new instrument, the model could learn to annotate and count rings in these other systems. We would like to explore this in the future, and to see other enthusiast-researchers try the same. Let us know how it goes! Perhaps together we can build a more general model/tool for doing inference on all sorts of ESPI (and even holography?) images...
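For what it's worth, here's a rough, untested sketch of what such a transfer-learning experiment might look like for the binary ("all-one") segmentation model, assuming a small new-instrument dataset laid out the same way as ours; the path, epoch counts, and learning rates below are placeholders, not tested settings:
new_path = Path('/path/to/new_instrument_data')  # placeholder path, not a real dataset
new_get_msk = lambda o: new_path/'masks'/f'{o.stem}_P{o.suffix}'
new_db = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_image_files,
                   splitter=RandomSplitter(),
                   get_y=new_get_msk,
                   batch_tfms=[*aug_transforms(size=sz, flip_vert=True), Normalize.from_stats(*imagenet_stats)])
new_dls = new_db.dataloaders(new_path/'images', bs=2)
new_learn = unet_learner(new_dls, resnet34, self_attention=True, act_cls=Mish, opt_func=ranger)
new_learn.load(checkpoint_file)  # assumes the 'seg_allone_full_real' weights are available in this learner's models/ dir
new_learn.freeze()               # first train only the decoder/head...
new_learn.fit_flat_cos(5, 1e-3)  # epoch count and learning rate are placeholders
new_learn.unfreeze()             # ...then fine-tune the whole network more gently
new_learn.fit_flat_cos(5, slice(1e-5, 1e-3))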
TODO: Add citation/copyright info for images