Mod of Zach Mueller's Walk With Fastai (WWF) notebook 01_Custom.ipynb.

Here we'll take cropped images of antinodes and try to count the rings, by fashioning a regression model out of a single-output "classification" model and scaling the output sigmoid (via fastai's y_range parameter) so that our fitted values stay within the roughly-linear regime of the sigmoid.

We also want to "clamp" our output between a minimum of about 0.2 rings and a maximum of 11 rings, because that's the range the dataset was created with; the sigmoid is a good choice for this "clamping" too.
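
Under the hood, fastai implements y_range with a scaled sigmoid (its SigmoidRange layer): the head's raw activation x gets mapped to sigmoid(x) * (hi - lo) + lo. A minimal sketch of that mapping in plain PyTorch:

import torch

def sigmoid_range(x, lo, hi):
    "Map an unbounded activation into the open interval (lo, hi) -- what y_range does to the model's final output."
    return torch.sigmoid(x) * (hi - lo) + lo

x = torch.tensor([-6.0, -1.0, 0.0, 1.0, 6.0])   # example raw activations
print(sigmoid_range(x, 0.2, 13.0))              # squashed into (0.2, 13.0); roughly linear near the middle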

 

Installs & Imports

!pip install espiownage pillow==8.2 --upgrade -q
import espiownage
from espiownage.core import *
sysinfo()
print(f"espiownage version {espiownage.__version__}")
TORCH_VERSION=torch1.8.0; CUDA_VERSION=cu101
CUDA available = True, Device count = 1, Current device = 0
Device name = GeForce RTX 2080 Ti
hostname: lecun
espiownage version 0.0.48

And import our libraries

from fastai.vision.all import *

Below you will find the exact imports for everything we use today

from fastcore.foundation import L
from fastcore.xtras import Path # @patch'd properties to the Pathlib module

from fastai.callback.fp16 import to_fp16
from fastai.callback.schedule import fit_one_cycle, lr_find 
from fastai.data.external import untar_data, URLs

from fastai.data.block import RegressionBlock, DataBlock
from fastai.data.transforms import get_image_files, Normalize, RandomSplitter, parent_label

from fastai.interpret import ClassificationInterpretation
from fastai.learner import Learner # imports @patch'd properties to Learner including `save`, `load`, `freeze`, and `unfreeze`
from fastai.optimizer import ranger

from fastai.vision.augment import aug_transforms, RandomResizedCrop, Resize, ResizeMethod
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.learner import cnn_learner
from fastai.vision.utils import download_images, verify_images

import os

Run parameters

dataset_name = 'cleaner' # choose from: 
                            # - cleaner (*real* data that's clean-er than "preclean"), 
                            # - preclean (unedited aggregates of 15-or-more volunteers)
                            # - spnet    (original SPNet Real dataset)
                            # - cyclegan (original SPNet CGSmall dataset)
                            # - fake (newer than SPNet fake, this includes non-int ring #s)
use_wandb = False        # WandB.ai logging
project = 'count_in_crops' # project name for wandb
if use_wandb: 
    !pip install wandb -qqq
    import wandb
    from fastai.callback.wandb import *
    from fastai.callback.tracker import SaveModelCallback
    wandb.login()

Prepare Dataset

path = get_data(dataset_name) / 'crops'; path
Path('/home/shawley/.espiownage/data/espiownage-cleaner/crops')
fnames = get_image_files(path)  # image filenames
print(f"{len(fnames)} total cropped images")
ind = 1  # pick one cropped image
fnames[ind]
6614 total cropped images
Path('/home/shawley/.espiownage/data/espiownage-cleaner/crops/06241902_proc_00618_0_36_76_255_11.0.png')

For labels, we want the ring count, which we extract from the filename: it's the number between the last '_' and the '.png'.

def label_func(x):  
    return round(float(x.stem.split('_')[-1]),2)

print(label_func(fnames[ind]))
11.0
cropsize = (300,300) # we will resize/reshape all input images to squares of this size
croppedrings = DataBlock(blocks=(ImageBlock, RegressionBlock(n_out=1)),
                    get_items=get_image_files,
                    splitter=RandomSplitter(),  # Note the random splitting. K-fold is another notebook
                    get_y=label_func,
                    item_tfms=Resize(cropsize, ResizeMethod.Squish),
                    batch_tfms=[*aug_transforms(size=cropsize, flip_vert=True, max_rotate=360.0), 
                    Normalize.from_stats(*imagenet_stats)])
# define dataloaders
dls = croppedrings.dataloaders(path, bs=32)

Take a look at some sample data and targets. Notice how circular the antinodes are! That's how we "got away with" arbitrary (360 degree) rotations in the DataBlock's batch_tfms above.

dls.show_batch(max_n=9)

Train model

opt = ranger # optimizer the kids love these days

y_range=(0.2,13)  # balance between "clamping" to range of real data vs too much "compression" from sigmoid nonlinearity

if use_wandb:
    wandb.init(project=project, name=f'{dataset_name}')
    cbs = [WandbCallback()]
else:
    cbs = []

learn = cnn_learner(dls, resnet34, n_out=1, y_range=y_range, 
                    metrics=[mae, acc_reg05, acc_reg07, acc_reg1,acc_reg15,acc_reg2], 
                    loss_func=MSELossFlat(), opt_func=opt, cbs=cbs)
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /home/shawley/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth
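
The acc_reg* metrics come from espiownage.core; going by their names, they report the fraction of predictions that land within a tolerance (0.5, 0.7, 1, 1.5, or 2 rings) of the target. A minimal sketch of that kind of tolerance metric, as an assumption about the behavior rather than the library's exact code:

import torch

def acc_within(preds, targs, tol=0.5):
    "Fraction of predictions within `tol` rings of the target (sketch of the acc_reg* idea)."
    return ((preds.flatten() - targs.flatten()).abs() <= tol).float().mean()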
learn.lr_find() # we're just going to use 5e-3 though
SuggestedLRs(lr_min=0.03630780577659607, lr_steep=1.0964781722577754e-06)
lr = 5e-3
epochs = 30 # 10-11 epochs is fine for lr=1e-2; here we do 30 w/ lower lr to see if we can "do better"
learn.fine_tune(epochs, lr, freeze_epochs=2)  
epoch train_loss valid_loss mae acc_reg05 acc_reg07 acc_reg1 acc_reg15 acc_reg2 time
0 1.298997 1.243957 0.718530 0.511346 0.664145 0.796520 0.897126 0.942511 00:22
1 1.517550 1.315192 0.750204 0.521936 0.661120 0.774584 0.869894 0.930408 00:22
epoch train_loss valid_loss mae acc_reg05 acc_reg07 acc_reg1 acc_reg15 acc_reg2 time
0 1.525957 1.293490 0.745135 0.496974 0.636914 0.764750 0.886536 0.931165 00:27
1 1.558694 1.256326 0.723578 0.515885 0.654312 0.782148 0.896369 0.936460 00:27
2 1.489346 1.263868 0.744777 0.493192 0.630862 0.771558 0.891831 0.934947 00:27
3 1.465624 1.289310 0.710750 0.544629 0.672466 0.795764 0.898638 0.937216 00:27
4 1.570813 1.316258 0.729875 0.502269 0.647504 0.785930 0.891831 0.938729 00:27
5 1.584235 1.290089 0.708399 0.560514 0.670953 0.789713 0.892587 0.931921 00:27
6 1.717140 1.521787 0.769059 0.513616 0.668684 0.777610 0.866868 0.926626 00:27
7 1.729477 1.338063 0.770505 0.472012 0.641452 0.757186 0.874433 0.928896 00:27
8 1.714307 1.514480 0.756737 0.530257 0.680787 0.793495 0.872920 0.920575 00:27
9 1.640525 1.411145 0.749479 0.524962 0.651286 0.781392 0.881241 0.926626 00:27
10 1.525180 1.568551 0.814826 0.440998 0.611952 0.757186 0.869894 0.921331 00:27
11 1.629610 1.360007 0.747457 0.517398 0.649017 0.775340 0.881997 0.931921 00:27
12 1.571880 1.253975 0.722277 0.515129 0.659607 0.782148 0.894100 0.932678 00:27
13 1.475989 1.346119 0.733487 0.520424 0.667171 0.789713 0.887292 0.934947 00:27
14 1.439114 1.272340 0.723906 0.523449 0.648260 0.780635 0.886536 0.935703 00:27
15 1.621275 1.245100 0.706564 0.543873 0.673222 0.782148 0.887292 0.937216 00:27
16 1.469432 1.272666 0.741773 0.493949 0.644478 0.773828 0.878971 0.934191 00:27
17 1.490656 1.454020 0.813164 0.449319 0.594554 0.736762 0.860817 0.923601 00:27
18 1.474666 1.254812 0.714538 0.535552 0.648260 0.773071 0.889561 0.935703 00:27
19 1.430211 1.364672 0.805295 0.422088 0.558245 0.732224 0.885779 0.932678 00:27
20 1.381298 1.198187 0.681863 0.570348 0.683056 0.795008 0.890318 0.938729 00:27
21 1.284741 1.229034 0.724479 0.514372 0.655824 0.773071 0.877458 0.936460 00:27
22 1.334033 1.196392 0.684910 0.559002 0.674735 0.795008 0.898638 0.940242 00:27
23 1.507739 1.229374 0.746835 0.484871 0.630862 0.784418 0.885779 0.937973 00:27
24 1.407409 1.253173 0.741548 0.490166 0.630862 0.777610 0.883510 0.937973 00:27
25 1.315197 1.229799 0.724383 0.515129 0.664902 0.791225 0.880484 0.935703 00:27
26 1.271938 1.210951 0.704713 0.550681 0.672466 0.795008 0.881241 0.938729 00:27
27 1.289100 1.254982 0.745295 0.486384 0.642965 0.773828 0.873676 0.935703 00:27
28 1.272271 1.248621 0.728451 0.510590 0.651286 0.778366 0.880484 0.936460 00:27
29 1.266804 1.262996 0.754478 0.465961 0.630106 0.778366 0.880484 0.937216 00:27

^ we could go back up and cut this off at 10, 15 or 20 epochs. In this case I just wanted to explore how low the val_loss would go!
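
If we wanted the run to stop itself instead (or to keep the best weights from a long run), fastai's tracker callbacks would do it. A minimal sketch, not what was run above:

from fastai.callback.tracker import EarlyStoppingCallback, SaveModelCallback

# stop when valid_loss hasn't improved for 5 epochs, and keep the best checkpoint seen so far
learn.fine_tune(epochs, lr, freeze_epochs=2,
                cbs=[EarlyStoppingCallback(monitor='valid_loss', patience=5),
                     SaveModelCallback(monitor='valid_loss', fname=f'crop-rings-{dataset_name}-best')])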

if use_wandb: wandb.finish()
learn.save(f'crop-rings-{dataset_name}') # save a checkpoint so we can restart from here later
Path('models/crop-rings-cleaner.pth')

Interpretation

learn.load(f'crop-rings-{dataset_name}'); # can start from here assuming learn, dls, etc are defined
preds, targs, losses = learn.get_preds(with_loss=True) # validation set only
print(f"We have {len(preds)} predictions.")
We have 1322 predictions.

Let's define a function to show a single prediction:

def showpred(ind, preds, targs, losses, dls): # show prediction at this index
    print(f"preds[{ind}] = {preds[ind]}, targs[{ind}] = {targs[ind]}, loss = {losses[ind]}")
    print(f"file = {os.path.basename(dls.valid.items[ind])}")
    print("Image:")
    dls.valid.dataset[ind][0].show()
showpred(0, preds, targs, losses, dls)
preds[0] = tensor([9.6772]), targs[0] = 11.0, loss = 1.7498469352722168
file = 06240907_proc_01447_0_107_170_318_11.0.png
Image:

And now we'll run through predictions for the whole validation set:

results = []
for i in range(len(preds)):
    line_list = [dls.valid.items[i].stem]+[round(targs[i].cpu().numpy().item(),2), round(preds[i][0].cpu().numpy().item(),2), losses[i].cpu().numpy(), i]
    results.append(line_list)

# store ring counts as a Pandas dataframe
res_df = pd.DataFrame(results, columns=['filename', 'target', 'prediction', 'loss','i'])

fastai doesn't define top_losses for this kind of task, but we can do our own version by sorting on the per-item loss:

res_df = res_df.sort_values('loss', ascending=False)
res_df.head()
filename target prediction loss i
322 06241902_proc_00632_0_39_90_238_1.0 1.0 10.91 98.200325 322
644 06241902_proc_01798_0_57_73_244_1.0 1.0 10.69 93.93914 644
883 06240907_proc_00320_0_103_193_332_2.7 2.7 9.30 43.568176 883
1118 06240907_proc_01197_0_105_185_338_3.3 3.3 9.31 36.06226 1118
684 06240907_proc_01021_0_93_183_344_3.8 3.8 9.64 34.15754 684
def show_top_losses(res_df, preds, targs, losses, dls, n=5):
    for j in range(n):
        showpred(res_df.iloc[j]['i'], preds, targs, losses, dls)
        
show_top_losses(res_df, preds, targs, losses, dls)
preds[322] = tensor([10.9096]), targs[322] = 1.0, loss = 98.20032501220703
file = 06241902_proc_00632_0_39_90_238_1.0.png
Image:
preds[644] = tensor([10.6922]), targs[644] = 1.0, loss = 93.93914031982422
file = 06241902_proc_01798_0_57_73_244_1.0.png
Image:
preds[883] = tensor([9.3006]), targs[883] = 2.700000047683716, loss = 43.56817626953125
file = 06240907_proc_00320_0_103_193_332_2.7.png
Image:
preds[1118] = tensor([9.3052]), targs[1118] = 3.299999952316284, loss = 36.062259674072266
file = 06240907_proc_01197_0_105_185_338_3.3.png
Image:
preds[684] = tensor([9.6444]), targs[684] = 3.799999952316284, loss = 34.15753936767578
file = 06240907_proc_01021_0_93_183_344_3.8.png
Image:

So we can then write these results out to a CSV file and use it to direct our data-cleaning efforts, i.e. look at the top-loss images first!

res_df.to_csv(f'ring_count_top_losses_{dataset_name}.csv', index=False)

Explore the Data

Let's take a look at plots of this data

df2 = res_df.reset_index(drop=True)
plt.plot(df2["target"],'o',label='target')
plt.plot(df2["prediction"],'s', label='prediction')
plt.xlabel('Top-loss order (left=worse, right=better)')
plt.legend(loc='lower right')
plt.ylabel('Ring count')
Text(0, 0.5, 'Ring count')
plt.plot(df2["target"],df2["prediction"],'o')
plt.xlabel('Target ring count')
plt.ylabel('Predicted ring count')
plt.axis('square')
(-0.23500000000000004, 11.887000000000002, -0.35100000000000003, 11.771)
print(f"Target ring count range: ({df2['target'].min()}, {df2['target'].max()})")
print(f"Predicted ring count range: ({df2['prediction'].min()}, {df2['prediction'].max()})")
Target ring count range: (0.3, 11.0)
Predicted ring count range: (0.2, 11.22)
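
As a quick numerical summary to go with these plots, we can also compute the validation MAE, RMSE, and target-prediction correlation straight from the dataframe; a minimal sketch:

err = df2['prediction'] - df2['target']
print(f"Validation MAE:  {err.abs().mean():.3f} rings")
print(f"Validation RMSE: {(err**2).mean()**0.5:.3f} rings")
print(f"Pearson r (target vs. prediction): {df2['target'].corr(df2['prediction']):.3f}")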

Plots for paper

We'll re-do the above plots using the saved CSV file and similar files for the other datasets in order to make a composite plot.
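
A minimal sketch of that composite, assuming each dataset has already been run through this notebook so that its ring_count_top_losses_{dataset_name}.csv file exists:

import pandas as pd
import matplotlib.pyplot as plt

datasets = ['cleaner', 'preclean', 'spnet', 'cyclegan', 'fake']
fig, ax = plt.subplots(figsize=(6,6))
for name in datasets:
    try:
        df = pd.read_csv(f'ring_count_top_losses_{name}.csv')
    except FileNotFoundError:
        continue  # skip any dataset we haven't run yet
    ax.plot(df['target'], df['prediction'], 'o', alpha=0.4, label=name)
ax.plot([0, 11], [0, 11], 'k--', lw=1)   # ideal y = x line
ax.set_xlabel('Target ring count'); ax.set_ylabel('Predicted ring count')
ax.axis('square'); ax.legend()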