```python
from torch.autograd import gradcheck
```
sqrtm
Steve Li’s method
MatrixSquareRoot_li
MatrixSquareRoot_li (*args, **kwargs)
From https://github.com/steveli/pytorch-sqrtm/blob/master/sqrtm.py, which sadly does not install as a package (LICENSE included below). Computes the square root of a positive definite matrix.
NOTE: matrix square root is not differentiable for matrices with zero eigenvalues.
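For context, Li’s approach wraps scipy in a custom torch.autograd.Function: the forward pass computes the square root on the CPU with scipy.linalg.sqrtm, and the backward pass solves a Sylvester equation for the gradient. A minimal sketch along those lines (the class name and details here are illustrative, not the exact source):

```python
import numpy as np
import scipy.linalg
import torch

class SqrtmSketch(torch.autograd.Function):
    "Illustrative scipy-backed matrix sqrt with a Sylvester-equation backward pass."
    @staticmethod
    def forward(ctx, A):
        # heavy lifting happens on the CPU in scipy; .real guards against tiny imaginary parts
        s = scipy.linalg.sqrtm(A.detach().cpu().numpy().astype(np.float64)).real
        s = torch.from_numpy(s).to(A)
        ctx.save_for_backward(s)
        return s

    @staticmethod
    def backward(ctx, grad_output):
        s, = ctx.saved_tensors
        sn = s.detach().cpu().numpy().astype(np.float64)
        gn = grad_output.detach().cpu().numpy().astype(np.float64)
        # A = S S  =>  dA = dS S + S dS: solve this Sylvester equation for the gradient
        grad = scipy.linalg.solve_sylvester(sn, sn, gn)
        return torch.from_numpy(grad).to(grad_output)
```

Note the CPU round-trip in both passes; this is part of why the GPU speed comparison below so strongly favors Newton-Schulz.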
Steve Li’s test code for the above:
```python
if use_li:
    k = torch.randn(1000, 1000).double()
    # Create a positive definite matrix
    pd_mat = (k.t().matmul(k)).requires_grad_()
    with torch.no_grad():
        sq = sqrtm_li(pd_mat)
    print("sq =\n",sq)
    #print("Running gradcheck...")
    #test = gradcheck(sqrtm_li, (pd_mat,))
    #print(test)
```
sq =
tensor([[ 2.6360e+01, -5.2896e-01, 4.6020e-01, ..., 4.6385e-01,
-2.5534e-01, 3.3804e-01],
[-5.2896e-01, 2.5773e+01, -7.2415e-01, ..., -2.6621e-02,
3.0918e-01, -7.8089e-02],
[ 4.6020e-01, -7.2415e-01, 2.5863e+01, ..., -5.2346e-01,
1.4617e-01, -2.5943e-01],
...,
[ 4.6385e-01, -2.6621e-02, -5.2346e-01, ..., 2.6959e+01,
3.6158e-01, -4.6653e-01],
[-2.5534e-01, 3.0918e-01, 1.4617e-01, ..., 3.6158e-01,
2.6692e+01, -4.4417e-01],
[ 3.3804e-01, -7.8089e-02, -2.5943e-01, ..., -4.6653e-01,
-4.4417e-01, 2.8916e+01]], dtype=torch.float64)
Subhransu Maji’s method(s)
From https://github.com/msubhransu/matrix-sqrt
sqrt_newton_schulz
sqrt_newton_schulz (A, numIters=20, calc_error=False)
Sqrt of matrix via the Newton-Schulz algorithm. Modified from https://github.com/msubhransu/matrix-sqrt/blob/cc2289a3ed7042b8dbacd53ce8a34da1f814ed2f/matrix_sqrt.py#LL72C1-L87C19: forward via Newton-Schulz iterations (non-autograd version), which seems to be slightly faster and has much lower memory overhead.
… The original code didn’t preserve the device and had no batch-dim checking. -SHH
A sketch of the coupled iteration appears after the parameter table below.
|  | Type | Default | Details |
|---|---|---|---|
| A |  |  | matrix to be sqrt-ified |
| numIters | int | 20 | numIters=7 found via experimentation |
| calc_error | bool | False | setting False disables Maji’s error reporting |
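To make the iteration concrete, here is a minimal sketch of the coupled Newton-Schulz update in the style of Maji’s code (function name and details here are illustrative): the input is scaled by its Frobenius norm so the iteration converges, Y converges to the square root of the scaled matrix, and Z to its inverse square root.

```python
import torch

def newton_schulz_sqrt_sketch(A, numIters=20):
    "Illustrative batched Newton-Schulz matrix sqrt; A is (batch, n, n) and positive definite."
    batch, n, _ = A.shape
    with torch.no_grad():
        normA = A.mul(A).sum(dim=(1, 2)).sqrt()           # per-matrix Frobenius norm
        Y = A / normA.view(batch, 1, 1)                   # scale so the iteration converges
        I = torch.eye(n, dtype=A.dtype, device=A.device).repeat(batch, 1, 1)
        Z = I.clone()
        for _ in range(numIters):
            T = 0.5 * (3.0 * I - Z.bmm(Y))                # shared update factor
            Y = Y.bmm(T)                                  # Y -> sqrt(A / normA)
            Z = T.bmm(Z)                                  # Z -> inverse sqrt(A / normA)
        return Y * normA.sqrt().view(batch, 1, 1)         # undo the normalization
```

The torch.no_grad() wrapper is what gives the non-autograd version its low memory overhead: none of the loop’s intermediate products are retained.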
sqrt_newton_schulz_autograd
sqrt_newton_schulz_autograd (A, numIters=20, calc_error=False)
Modified from https://people.cs.umass.edu/~smaji/projects/matrix-sqrt/: “The drawback of the autograd approach [i.e., this approach] is that a naive implementation stores all the intermediate results. Thus the memory overhead scales linearly with the number of iterations which is problematic for large matrices.” A short backprop example follows the table below.
|  | Type | Default | Details |
|---|---|---|---|
| A |  |  |  |
| numIters | int | 20 | found experimentally by SHH, comparing w/ Li’s method |
| calc_error | bool | False |  |
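The memory drawback quoted above comes from autograd retaining each iteration’s intermediate bmm results for the backward pass. A hedged usage example (this assumes calc_error=False returns just the square-root tensor, matching how these functions are called below):

```python
import torch

k = torch.randn(16, 16).double()
pd = (k.t() @ k).unsqueeze(0).requires_grad_()    # batched positive definite input
sA = sqrt_newton_schulz_autograd(pd, numIters=20)
sA.sum().backward()                               # graph holds every iteration's bmm outputs
print(pd.grad.shape)                              # torch.Size([1, 16, 16])
```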
compute_error
compute_error (A, sA)
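The docs don’t spell out the formula, but Maji-style error reporting is presumably the relative Frobenius-norm residual of the square root; a sketch under that assumption:

```python
import torch

def compute_error_sketch(A, sA):
    "Relative residual ||A - sA @ sA||_F / ||A||_F per batch element (illustrative guess)."
    normA = A.mul(A).sum(dim=(1, 2)).sqrt()
    resid = A - sA.bmm(sA)
    return resid.mul(resid).sum(dim=(1, 2)).sqrt() / normA
```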
Error tests
```python
sa1, error = sqrt_newton_schulz_autograd( pd_mat.unsqueeze(0), numIters=20, calc_error=True )
print("sa1 =\n",sa1)
print("error =",error.detach().item())
```
sa1 =
tensor([[[ 2.6360e+01, -5.2896e-01, 4.6021e-01, ..., 4.6384e-01,
-2.5535e-01, 3.3805e-01],
[-5.2896e-01, 2.5773e+01, -7.2415e-01, ..., -2.6632e-02,
3.0917e-01, -7.8092e-02],
[ 4.6021e-01, -7.2415e-01, 2.5863e+01, ..., -5.2344e-01,
1.4618e-01, -2.5943e-01],
...,
[ 4.6384e-01, -2.6632e-02, -5.2344e-01, ..., 2.6959e+01,
3.6150e-01, -4.6653e-01],
[-2.5535e-01, 3.0917e-01, 1.4618e-01, ..., 3.6150e-01,
2.6692e+01, -4.4418e-01],
[ 3.3805e-01, -7.8092e-02, -2.5943e-01, ..., -4.6653e-01,
-4.4418e-01, 2.8916e+01]]], dtype=torch.float64,
grad_fn=<MulBackward0>)
error = 4.759428080865442e-08
```python
sa2, error = sqrt_newton_schulz( pd_mat.unsqueeze(0), numIters=20, calc_error=True )
print("sa2 =\n",sa2)
print("error =",error.detach().item())
```
sa2 =
tensor([[[ 2.6360e+01, -5.2896e-01, 4.6021e-01, ..., 4.6384e-01,
-2.5535e-01, 3.3805e-01],
[-5.2896e-01, 2.5773e+01, -7.2415e-01, ..., -2.6632e-02,
3.0917e-01, -7.8092e-02],
[ 4.6021e-01, -7.2415e-01, 2.5863e+01, ..., -5.2344e-01,
1.4618e-01, -2.5943e-01],
...,
[ 4.6384e-01, -2.6632e-02, -5.2344e-01, ..., 2.6959e+01,
3.6150e-01, -4.6653e-01],
[-2.5535e-01, 3.0917e-01, 1.4618e-01, ..., 3.6150e-01,
2.6692e+01, -4.4418e-01],
[ 3.3805e-01, -7.8092e-02, -2.5943e-01, ..., -4.6653e-01,
-4.4418e-01, 2.8916e+01]]], dtype=torch.float64,
grad_fn=<MulBackward0>)
error = 4.759428080865442e-08
```python
if use_li:
    diff = sa1 - sq
    print("diff = \n",diff)
```
diff =
tensor([[[-3.7835e-05, 3.8437e-07, 8.7208e-06, ..., -1.2816e-05,
-1.8744e-05, 1.3386e-05],
[ 3.8437e-07, -1.9896e-06, 2.2977e-06, ..., -1.1004e-05,
-1.0735e-05, -3.3250e-06],
[ 8.7208e-06, 2.2977e-06, -7.5329e-06, ..., 2.2518e-05,
1.9394e-05, -3.9655e-06],
...,
[-1.2816e-05, -1.1004e-05, 2.2518e-05, ..., -8.1416e-05,
-7.1274e-05, -1.4799e-06],
[-1.8744e-05, -1.0735e-05, 1.9394e-05, ..., -7.1274e-05,
-8.0163e-05, -1.6137e-05],
[ 1.3386e-05, -3.3250e-06, -3.9655e-06, ..., -1.4799e-06,
-1.6137e-05, -2.6324e-05]]], dtype=torch.float64,
grad_fn=<SubBackward0>)
```python
if use_li:
    diff = sa2 - sq
    print("diff = \n",diff)
```
diff =
tensor([[[-3.7835e-05, 3.8437e-07, 8.7208e-06, ..., -1.2816e-05,
-1.8744e-05, 1.3386e-05],
[ 3.8437e-07, -1.9896e-06, 2.2977e-06, ..., -1.1004e-05,
-1.0735e-05, -3.3250e-06],
[ 8.7208e-06, 2.2977e-06, -7.5329e-06, ..., 2.2518e-05,
1.9394e-05, -3.9655e-06],
...,
[-1.2816e-05, -1.1004e-05, 2.2518e-05, ..., -8.1416e-05,
-7.1274e-05, -1.4799e-06],
[-1.8744e-05, -1.0735e-05, 1.9394e-05, ..., -7.1274e-05,
-8.0163e-05, -1.6137e-05],
[ 1.3386e-05, -3.3250e-06, -3.9655e-06, ..., -1.4799e-06,
-1.6137e-05, -2.6324e-05]]], dtype=torch.float64,
grad_fn=<SubBackward0>)
Speed & device tests
```python
from aeiou.core import get_device

device = get_device()
print('device = ',device)

n,m = 1000, 1000
with torch.no_grad():
    k = torch.randn(n, m, device=device)
    pd_mat2 = (k.t().matmul(k)) # Create a positive definite matrix, no grad
```
device = cuda
```python
# %%timeit
if use_li:
    sq2 = sqrtm_li(pd_mat2)
```
Result of %%timeit: 1.12 s ± 191 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```python
print(sq2)
```
NameError: name 'sq2' is not defined

(The NameError is expected here when use_li is False, since sq2 is never assigned in that case.)
```python
# %%timeit
sq3 = sqrt_newton_schulz(pd_mat2.unsqueeze(0), numIters=20)[0]
```
Result of %%timeit: 8.8 ms ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Wrapper around our method of choice:
TL;DR: we’ll use Maji’s Newton-Schulz method. Newton-Schulz is an approximate iterative method rather than an exact matrix sqrt; however, with 7 iterations the error is below 1e-5, (presumably significantly) lower than other errors in the problem.
sqrtm
sqrtm (A, method='maji', numIters=20)
Wrapper function for the matrix sqrt algorithm of choice. Also turns off all gradients.
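A hypothetical sketch of what such a wrapper can look like (the dispatch logic and batch-dim handling here are assumptions, not the exact source):

```python
import torch

def sqrtm_sketch(A, method='maji', numIters=20):
    "Illustrative dispatch wrapper; gradients disabled per the docstring above."
    with torch.no_grad():
        if method == 'li':
            return sqrtm_li(A)
        # Maji's routine expects a batch dim: add one for a 2-d input, strip it afterward
        return sqrt_newton_schulz(A.unsqueeze(0), numIters=numIters)[0]
```

The trailing [0] mirrors the indexing used in the speed test above.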
```python
sqrtm(pd_mat2, method='maji') - sqrtm(pd_mat2, method='li')
```
tensor([[ 1.9073e-06, 1.5616e-05, 9.0003e-06, ..., -8.9407e-07,
-3.7104e-05, 6.7353e-06],
[ 1.9193e-05, -6.1035e-05, -4.7386e-06, ..., 7.8678e-06,
2.9802e-05, -5.8711e-06],
[ 7.5698e-06, -5.4836e-06, -3.8147e-05, ..., 2.0236e-05,
1.6876e-06, -2.9024e-05],
...,
[-3.3975e-06, 3.6955e-06, 2.2471e-05, ..., -1.1444e-05,
3.6955e-06, 1.2971e-05],
[-3.6355e-05, 2.5883e-05, 7.5437e-06, ..., 9.4771e-06,
-6.8665e-05, 6.7949e-06],
[ 4.2319e-06, -1.3113e-05, -2.6686e-05, ..., 9.9763e-06,
1.1921e-05, 5.7220e-06]], device='cuda:0')