fad_score

Produce FAD score based on files of embeddings of real and fake data

\[ FAD = || \mu_r - \mu_f ||^2 + tr\left(\Sigma_r + \Sigma_f - 2 \sqrt{\Sigma_r \Sigma_f}\right)\]

The embeddings are small enough that this can typically be run on a single processor, on a CPU. However, all the supporting code is GPU-friendly if so desired.


source

read_embeddings

 read_embeddings (emb_path='real_emb_clap/', debug=False)

reads any .pt files in emb_path and concatenates them into one tensor

# lil test of that
e = read_embeddings()
e.shape
torch.Size([256, 512])

source

calc_mu_sigma

 calc_mu_sigma (emb)

calculates mean and covariance matrix of batched embeddings

# quick test:
x = torch.rand(32,512) 
mu, sigma = calc_mu_sigma(x) 
mu.shape, sigma.shape
(torch.Size([512]), torch.Size([512, 512]))

source

calc_score

 calc_score (real_emb_path, fake_emb_path, method='maji', debug=False)
Type Default Details
real_emb_path where real embeddings are stored
fake_emb_path where fake embeddings are stored
method str maji sqrtm calc method: ‘maji’|‘li’
debug bool False

Test the score function:

score = calc_score( 'real_emb_clap/', 'fake_emb_clap/', method='maji')
print(score)
Calculating FAD score for files in real_emb_clap// vs. fake_emb_clap//
tensor(0.0951)

Try sending using the exact same data for both distributions: Do we get zero?

score = calc_score( 'real_emb_clap/', 'real_emb_clap/', method='maji', debug=True)
print(score)
Calculating FAD score for files in real_emb_clap// vs. real_emb_clap//
searching in  real_emb_clap/
searching in  real_emb_clap/
torch.Size([256, 512]) torch.Size([256, 512])
mu_real.shape, sigma_real.shape = torch.Size([512]) torch.Size([512, 512])
mu_fake.shape, sigma_fake.shape = torch.Size([512]) torch.Size([512, 512])
mu_diff =  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])
score1: mu_diff.dot(mu_diff) =  tensor(0.)
score2: torch.trace(sigma_real) =  tensor(0.4448)
score3: torch.trace(sigma_fake) =  tensor(0.4448)
score_p.shape (matmul) =  torch.Size([512, 512])
score4 (-2*tr(sqrtm(matmul(sigma_r sigma_f))))  =  tensor(-0.8888)
tensor(0.0008)

Ok, so not zero, but small.


source

main

 main ()