Modified from Zach Mueller's Walk with fastai (WWF) notebook 01_Custom.ipynb.
Here we'll take cropped images of antinodes and try to count the rings.
Note: This is a stripped-down, streamlined version of the longer "Counting Rings in Cropped Images" notebook, with k-fold cross-validation added. See that notebook for a fuller explanation of what we're doing.
Note: The WandB links will 404, because there is no "drscotthawley" WandB account; we used `sed` to replace the real username in the .ipynb files.
import espiownage
from espiownage.core import *
# sysinfo() arrives via the espiownage.core star import; presumably it reports the runtime environment.
sysinfo()
print("espiownage version {}".format(espiownage.__version__))
from fastai.vision.all import *
from fastcore.foundation import L
from fastcore.xtras import Path # @patch'd properties to the Pathlib module
from fastai.callback.fp16 import to_fp16
from fastai.callback.schedule import fit_one_cycle, lr_find
from fastai.data.external import untar_data, URLs
from fastai.data.block import RegressionBlock, DataBlock
from fastai.data.transforms import get_image_files, Normalize, RandomSplitter, parent_label
from fastai.interpret import ClassificationInterpretation
from fastai.learner import Learner # imports @patch'd properties to Learner including `save`, `load`, `freeze`, and `unfreeze`
from fastai.optimizer import ranger
from fastai.vision.augment import aug_transforms, RandomResizedCrop, Resize, ResizeMethod
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.learner import cnn_learner
from fastai.vision.utils import download_images, verify_images
from mrspuff.utils import on_colab
# Which flavor of data to train on: cleaner(=real), preclean, spnet, cyclegan, fake
dataset_name = 'cleaner'
path = get_data(dataset_name) / 'crops'
print(path)

fnames = get_image_files(path)
print("{} total cropped images".format(len(fnames)))

# Peek at one cropped image's path.
ind = 1
fnames[ind]
def label_func(x, ndigits=2):
    """Return the ring count encoded at the end of a crop's filename.

    The label is the last ``_``-separated field of the file stem, parsed as a
    float and rounded, e.g. ``img_12_3.456.png`` -> ``3.46``.

    Args:
        x: path-like object exposing ``.stem`` (e.g. ``pathlib.Path``).
        ndigits: decimal places to round the label to (default 2).

    Returns:
        The parsed label as a float.
    """
    return round(float(x.stem.split('_')[-1]), ndigits)
print(label_func(fnames[ind]))  # sanity check: label parsed from the sample file's name
cropsize = (300,300) # pixels; used below as the Resize / aug_transforms target size
# IPython shell escape (Jupyter/IPython only -- not valid in a plain .py script): install wandb quietly
!pip install wandb -qqq
import wandb
from fastai.callback.wandb import *  # provides WandbCallback, which is passed to the learner below
# NOTE(review): may prompt interactively for an API key if none is configured -- confirm for headless runs
wandb.login()
# --- k-fold cross-validation split -------------------------------------------
# Set k = 0 .. nk-1 and re-run everything from here down to train each fold.
kfold = True  # flag kept for reference; not read in the code shown here
k = 0   # which fold is held out for validation
nk = 5  # number of folds
n_items = len(fnames)
if not 0 <= k < nk:
    raise ValueError(f"fold index k={k} must be in range 0..{nk - 1}")
nv = n_items // nk  # nominal size of each validation fold
bgn = k * nv        # index where this fold's validation slice starts
# The last fold absorbs the n_items % nk leftover files; without this those
# files would never be used for validation in any fold.
end = n_items if k == nk - 1 else bgn + nv
inds = list(range(bgn, end))  # indices for this val set
# The indices refer to the order returned by get_image_files(path), the same
# call used to build `fnames`.
# NOTE(review): folds are contiguous slices of the file listing, not a shuffled
# split; if files are grouped on disk (e.g. by source image), folds may be unbalanced.

croppedrings = DataBlock(blocks=(ImageBlock, RegressionBlock(n_out=1)),
                         get_items=get_image_files,
                         splitter=IndexSplitter(inds),
                         get_y=label_func,
                         item_tfms=Resize(cropsize, ResizeMethod.Squish),
                         batch_tfms=[*aug_transforms(size=cropsize, flip_vert=True, max_rotate=360.0),
                                     Normalize.from_stats(*imagenet_stats)])
dls = croppedrings.dataloaders(path, bs=32)
def acc_reg07(inp, targ):
    """Regression "accuracy" metric with a fixed bin size of 0.7.

    Thin wrapper over ``acc_reg`` (presumably from ``espiownage.core``) that
    pins ``bin_size=0.7``, used alongside the other ``acc_reg*`` metrics
    passed to the learner.
    """
    bin_size = 0.7
    return acc_reg(inp, targ, bin_size=bin_size)
opt = ranger
# y_range balances "clamping" predictions to the range of the real data
# against too much "compression" from the sigmoid nonlinearity.
y_range = (0.2, 13)

wandb.init(project='ringcounts_kfold', name=f'k={k},{dataset_name}')

ring_metrics = [mae, acc_reg05, acc_reg07, acc_reg1, acc_reg15, acc_reg2]
learn = cnn_learner(
    dls,
    resnet34,
    n_out=1,
    y_range=y_range,
    metrics=ring_metrics,
    loss_func=MSELossFlat(),
    opt_func=opt,
    cbs=WandbCallback(),
)
# To choose a learning rate, run learn.lr_find() before training.
lr = 5e-3
epochs = 30
freeze_epochs = 2
print(f"Training for {epochs} epochs, with {freeze_epochs} frozen epochs first")
# Pass the variable (not a literal) so the message above and the actual
# training schedule cannot drift apart.
learn.fine_tune(epochs, lr, freeze_epochs=freeze_epochs)
wandb.finish()  # close this fold's W&B run
# NOTE(review): the checkpoint name says "real" regardless of dataset_name;
# confirm before training on other datasets to avoid overwriting checkpoints.
learn.save(f'crop-rings-real_k{k}')
learn.load(f'crop-rings-real_k{k}');
# Predictions, targets and per-item losses on the validation set only.
preds, targs, losses = learn.get_preds(with_loss=True)
print(f"len(preds) = {len(preds)}")
def showpred(ind, preds, targs, losses, dls):
    """Print prediction, target and loss for one validation item, then show its image.

    Args:
        ind: index into the validation set (and into preds/targs/losses).
        preds, targs, losses: per-item outputs, e.g. from
            ``learn.get_preds(with_loss=True)``.
        dls: the DataLoaders whose ``valid`` split produced those outputs.
    """
    pred, targ, loss = preds[ind], targs[ind], losses[ind]
    print(f"preds[{ind}] = {pred}, targs[{ind}] = {targ}, loss = {loss}")
    print(f"file = {dls.valid.items[ind]}")
    print("Image:")
    image = dls.valid.dataset[ind][0]
    image.show()
# Example: showpred(0, preds, targs, losses, dls)

# One row per validation item: file stem, target, prediction, loss, index.
# Assumes dls.valid.items is in the same (unshuffled) order as the get_preds outputs.
# .item() yields plain Python floats, so 'loss' is a numeric column rather
# than an object column of 0-d numpy arrays.
results = [
    [item.stem, round(targ.item(), 2), round(pred[0].item(), 2), loss.item(), i]
    for i, (item, targ, pred, loss) in enumerate(zip(dls.valid.items, targs, preds, losses))
]

# Store as a pandas dataframe, worst (highest-loss) predictions first.
res_df = pd.DataFrame(results, columns=['filename', 'target', 'prediction', 'loss', 'i'])
res_df = res_df.sort_values('loss', ascending=False)
res_df.head()
# NOTE(review): filename says "real" regardless of dataset_name -- confirm.
res_df.to_csv(f'ring_count_top_losses_real_k{k}.csv', index=False)