Modified from Zach Mueller's Walk with fastai (WWF) notebook 01_Custom.ipynb.

Here we'll take cropped images of antinodes and try to count the rings. We build a regression model out of a single-output classification model, scaling the output sigmoid (via fastai's `y_range` parameter) so that our fitted values stay within the sigmoid's near-linear regime.

We also want to "clamp" our output between a minimum of about 0.2 rings and a maximum of 11 rings, because that is the range the dataset was created with; the sigmoid is a good choice for this "clamping" too.

 

Installs & Imports

import espiownage
from espiownage.core import *
sysinfo()
print(f"espiownage version {espiownage.__version__}")
TORCH_VERSION=torch1.9.0; CUDA_VERSION=cu102
CUDA available = True, Device count = 1, Current device = 0
Device name = NVIDIA GeForce RTX 2070 with Max-Q Design
hostname: oryxpro
espiownage version 0.0.47

And import our libraries

from fastai.vision.all import *
from espiownage.core import *

Below you will find the exact imports for everything we use today

from fastcore.foundation import L
from fastcore.xtras import Path # @patch'd properties to the Pathlib module

from fastai.callback.fp16 import to_fp16
from fastai.callback.schedule import fit_one_cycle, lr_find 
from fastai.data.external import untar_data, URLs

from fastai.data.block import RegressionBlock, DataBlock
from fastai.data.transforms import get_image_files, Normalize, RandomSplitter, parent_label

from fastai.interpret import ClassificationInterpretation
from fastai.learner import Learner # imports @patch'd properties to Learner including `save`, `load`, `freeze`, and `unfreeze`
from fastai.optimizer import ranger

from fastai.vision.augment import aug_transforms, RandomResizedCrop, Resize
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.learner import cnn_learner
from fastai.vision.utils import download_images, verify_images

Run parameters

dataset_name = 'spnet' # choose from: 
                            # - cleaner (*real* data that's clean-er than "preclean"), 
                            # - preclean (unedited aggregates of 15-or-more volunteers)
                            # - spnet,   (original SPNet Real dataset)
                            # - cyclegan (original SPNet CGSmall dataset)
                            # - fake (newer than SPNet fake, this includes non-int ring #s)
use_wandb = False        # WandB.ai logging
project = 'count_in_crops' # project name for wandb
if use_wandb: 
    !pip install wandb -qqq
    import wandb
    from fastai.callback.wandb import *
    from fastai.callback.tracker import SaveModelCallback
    wandb.login()
path = get_data(dataset_name) / 'crops'; path
Path('/home/drscotthawley/.espiownage/data/espiownage-spnet/crops')
fnames = get_image_files(path)
print(f"{len(fnames)} total cropped images")
ind = 1  # pick one cropped image
fnames[ind]
2045 total cropped images
Path('/home/drscotthawley/.espiownage/data/espiownage-spnet/crops/06240907_proc_00459_229_6_286_81_1.25.png')

For labels, we want the ring count, which is the number between the last '_' and the '.png'.

def label_func(x):
    """Return the ring count encoded in the filename of `x`, rounded to 2 decimals.

    The count is the final '_'-separated field of the filename stem,
    e.g. '..._286_81_1.25.png' -> 1.25.
    """
    ring_field = x.stem.rsplit('_', 1)[-1]  # text after the last underscore
    ring_count = float(ring_field)
    return round(ring_count, 2)

print(label_func(fnames[ind]))
1.25
cropsize = (300,300) # pixels
croppedrings = DataBlock(blocks=(ImageBlock, RegressionBlock(n_out=1)),
                    get_items=get_image_files,
                    # Seeded so the train/valid split is identical on every run: the saved model
                    # (learn.save / learn.load) and the top-loss CSV refer to one specific validation
                    # set, and an unseeded split would put training images into "validation" after a restart.
                    splitter=RandomSplitter(valid_pct=0.2, seed=42),
                    get_y=label_func,
                    item_tfms=Resize(cropsize, ResizeMethod.Squish),
                    batch_tfms=[*aug_transforms(size=cropsize, flip_vert=True, max_rotate=360.0), 
                    Normalize.from_stats(*imagenet_stats)])
dls = croppedrings.dataloaders(path, bs=32)
/home/drscotthawley/envs/espi/lib/python3.9/site-packages/torch/_tensor.py:1023: UserWarning: torch.solve is deprecated in favor of torch.linalg.solveand will be removed in a future PyTorch release.
torch.linalg.solve has its arguments reversed and does not return the LU factorization.
To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
X = torch.solve(B, A).solution
should be replaced with
X = torch.linalg.solve(A, B) (Triggered internally at  /pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp:760.)
  ret = func(*args, **kwargs)
dls.show_batch(max_n=9)

Train model

opt = ranger
# Sigmoid output range: a bit wider than the data's ring counts (~0.2 to 11) so targets sit
# in the sigmoid's near-linear region instead of being compressed near its ends.
y_range = (0.2, 13)
learn = cnn_learner(
    dls, resnet34,
    n_out=1,
    y_range=y_range,
    metrics=[mae, acc_reg05, acc_reg1, acc_reg15, acc_reg2],
    loss_func=MSELossFlat(),
    opt_func=opt,
)
/home/drscotthawley/envs/espi/lib/python3.9/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
learn.lr_find()
/home/drscotthawley/envs/espi/lib/python3.9/site-packages/fastai/callback/schedule.py:269: UserWarning: color is redundantly defined by the 'color' keyword argument and the fmt string "ro" (-> color='r'). The keyword argument will take precedence.
  ax.plot(val, idx, 'ro', label=nm, c=color)
SuggestedLRs(valley=0.009120108559727669)
lr = 5e-3
learn.fine_tune(30, lr, freeze_epochs=2)  # accidentally ran this twice in a row :-O
epoch train_loss valid_loss mae acc_reg05 acc_reg1 acc_reg15 acc_reg2 time
0 19.282076 9.722394 2.462277 0.095355 0.266504 0.420538 0.533007 00:12
1 16.204414 12.365723 2.698094 0.129584 0.278729 0.403423 0.501222 00:12
epoch train_loss valid_loss mae acc_reg05 acc_reg1 acc_reg15 acc_reg2 time
0 12.168900 9.776813 2.327484 0.176039 0.342298 0.466993 0.572127 00:16
1 11.358937 8.480050 2.271135 0.112469 0.293399 0.437653 0.550122 00:16
2 10.466084 8.561608 2.285184 0.085575 0.283619 0.435208 0.574572 00:16
3 10.245325 4.849682 1.665897 0.205379 0.405868 0.581907 0.679707 00:16
4 9.041600 9.689653 2.364898 0.151589 0.322738 0.449878 0.550122 00:16
5 7.896089 7.220519 2.001585 0.149144 0.386308 0.528117 0.628362 00:16
6 6.740160 3.577236 1.383209 0.281174 0.496333 0.652812 0.765281 00:16
7 5.755038 1.830490 0.979590 0.383863 0.633252 0.787286 0.855746 00:16
8 4.053399 2.248028 1.111390 0.342298 0.562347 0.723716 0.848411 00:16
9 3.193851 2.007382 1.030551 0.374083 0.606357 0.753056 0.863081 00:16
10 2.732466 1.752073 0.953972 0.418093 0.633252 0.782396 0.877751 00:16
11 2.691343 1.916968 0.998160 0.400978 0.608802 0.765281 0.858191 00:16
12 2.656349 1.751836 0.946858 0.396088 0.652812 0.794621 0.882641 00:16
13 2.502453 1.813630 1.018652 0.312958 0.633252 0.789731 0.855746 00:16
14 2.276822 1.644336 0.943956 0.359413 0.679707 0.789731 0.882641 00:16
15 2.002220 2.007890 1.014130 0.388753 0.635697 0.765281 0.858191 00:16
16 2.025363 1.607281 0.936667 0.366748 0.682152 0.799511 0.889976 00:16
17 1.949465 1.485418 0.855687 0.457213 0.687042 0.831296 0.909535 00:16
18 1.927312 1.541676 0.850113 0.457213 0.699266 0.836186 0.907090 00:16
19 1.975950 1.630775 0.904536 0.420538 0.665037 0.823961 0.897310 00:16
20 1.924967 1.713278 0.964017 0.359413 0.665037 0.792176 0.887531 00:16
21 1.782603 1.547323 0.909293 0.376528 0.672372 0.821516 0.902200 00:16
22 1.777079 1.543253 0.886471 0.427873 0.701711 0.823961 0.894866 00:16
23 1.750983 1.572344 0.884398 0.435208 0.684597 0.826406 0.897310 00:16
24 1.715748 1.557991 0.873965 0.432763 0.696822 0.833741 0.904646 00:16
25 1.685339 1.397585 0.813730 0.466993 0.721271 0.850856 0.909535 00:16
26 1.692378 1.431694 0.846405 0.432763 0.723716 0.841076 0.909535 00:16
27 1.612728 1.423968 0.839258 0.444988 0.716381 0.838631 0.907090 00:16
28 1.665606 1.490422 0.862552 0.432763 0.699266 0.838631 0.902200 00:16
29 1.607735 1.403842 0.835166 0.442543 0.723716 0.841076 0.916870 00:16
learn.save(f'crop-rings-{dataset_name}')
Path('models/crop-rings-spnet.pth')

Interpretation

learn.load(f'crop-rings-{dataset_name}');
preds, targs, losses = learn.get_preds(with_loss=True) # validation set only
len(preds)
409

I'll define a function to show a single prediction:

def showpred(ind, preds, targs, losses, dls): # show prediction at this index
    """Print prediction, target and loss for validation item `ind`, then display its image.

    `preds`, `targs` and `losses` are indexed by `ind`. `dls.valid.items[ind]` supplies the
    file path (only its basename is printed), and element 0 of `dls.valid.dataset[ind]`
    is the object whose `.show()` is called.
    """
    pred, targ, loss = preds[ind], targs[ind], losses[ind]
    print(f"preds[{ind}] = {pred}, targs[{ind}] = {targ}, loss = {loss}")
    fname = os.path.basename(dls.valid.items[ind])
    print(f"file = {fname}")
    print("Image:")
    image = dls.valid.dataset[ind][0]
    image.show()
showpred(0, preds, targs, losses, dls)
preds[0] = tensor([2.6761]), targs[0] = 2.0, loss = 0.4571508765220642
file = 06240907_proc_00463_211_0_284_87_2.0.png
Image:

And now we'll run through predictions for the whole validation set:

# One row per validation item: filename stem, target, prediction, loss, and the item's
# index in the validation set (used later to look the item up again with showpred).
# .item() turns each tensor element into a plain Python float, so 'loss' is stored as a
# number rather than as a 0-d numpy array (which pandas keeps in an object-dtype column).
results = [
    [item.stem, round(targ.item(), 2), round(pred[0].item(), 2), loss.item(), i]
    for i, (item, targ, pred, loss) in enumerate(zip(dls.valid.items, targs, preds, losses))
]

# store as pandas dataframe
res_df = pd.DataFrame(results, columns=['filename', 'target', 'prediction', 'loss', 'i'])

We can do our own version of printing top_losses:

res_df = res_df.sort_values('loss', ascending=False)
res_df.head()
filename target prediction loss i
402 06240907_proc_00791_0_105_198_335_2.5 2.50 7.88 28.979721 402
287 06240907_proc_00320_0_103_193_332_2.67 2.67 7.62 24.472479 287
262 06240907_proc_00852_0_127_183_300_2.7 2.70 7.55 23.496952 262
132 06240907_proc_00506_0_115_173_304_10.67 10.67 6.88 14.388533 132
100 06240910_proc_00607_8_50_251_259_5.0 5.00 8.65 13.339838 100
def show_top_losses(res_df, preds, targs, losses, dls, n=5):
    """Call `showpred` for the first `n` rows of `res_df`, in the frame's current order.

    Each row's 'i' column is the validation-set index passed to `showpred`. If `res_df`
    has fewer than `n` rows, all rows are shown instead of raising IndexError; a
    non-positive `n` shows nothing.
    """
    # head() clamps to the available rows; max(n, 0) keeps negative n from meaning
    # "all but the last |n| rows". int() hands showpred a plain Python index.
    for ind in res_df['i'].head(max(n, 0)):
        showpred(int(ind), preds, targs, losses, dls)
        
show_top_losses(res_df, preds, targs, losses, dls)
preds[402] = tensor([7.8833]), targs[402] = 2.5, loss = 28.979721069335938
file = 06240907_proc_00791_0_105_198_335_2.5.png
Image:
preds[287] = tensor([7.6170]), targs[287] = 2.6700000762939453, loss = 24.47247886657715
file = 06240907_proc_00320_0_103_193_332_2.67.png
Image:
preds[262] = tensor([7.5474]), targs[262] = 2.700000047683716, loss = 23.496952056884766
file = 06240907_proc_00852_0_127_183_300_2.7.png
Image:
preds[132] = tensor([6.8768]), targs[132] = 10.670000076293945, loss = 14.388532638549805
file = 06240907_proc_00506_0_115_173_304_10.67.png
Image:
preds[100] = tensor([8.6524]), targs[100] = 5.0, loss = 13.339838027954102
file = 06240910_proc_00607_8_50_251_259_5.0.png
Image:

We can then write these results to a CSV file and use it to direct our data-cleaning efforts, i.e., look at the top-loss images first!

res_df.to_csv(f'ring_count_top_losses_{dataset_name}.csv', index=False)

When in doubt, look at the data...

Let's take a look at plots of this data

df2 = res_df.reset_index(drop=True)
# Targets as circles, predictions as squares, both plotted against top-loss rank.
for col, marker in (("target", 'o'), ("prediction", 's')):
    plt.plot(df2[col], marker, label=col)
plt.xlabel('Top-loss order (left=worse, right=better)')
plt.legend(loc='lower right')
plt.ylabel('Ring count')
Text(0, 0.5, 'Ring count')
# Predicted vs. target ring count on equally scaled axes.
targets, predictions = df2["target"], df2["prediction"]
plt.plot(targets, predictions, 'o')
plt.xlabel('Target ring count')
plt.ylabel('Predicted ring count')
plt.axis('square')
(0.15349999999999997, 11.5165, -0.24500000000000005, 11.118000000000002)
# Report min/max of targets and predictions over the validation set.
for label, col in (("Target", "target"), ("Predicted", "prediction")):
    print(f"{label} ring count range: ({df2[col].min()}, {df2[col].max()})")
Target ring count range: (0.67, 11.0)
Predicted ring count range: (0.2, 9.1)