sqrtm

Methods for computing the square root of a matrix

Steve Li’s method



MatrixSquareRoot_li

 MatrixSquareRoot_li (*args, **kwargs)

Square root of a positive definite matrix. From https://github.com/steveli/pytorch-sqrtm/blob/master/sqrtm.py, which sadly does not install as a package. LICENSE included below.

NOTE: matrix square root is not differentiable for matrices with zero eigenvalues.
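To make the note above concrete, here is a minimal sketch for the diagonal case, where the matrix square root reduces to the elementwise square root of the eigenvalues. The helper name `eigenvalue_sqrt_grad` is hypothetical and is not part of Li's code. It only illustrates that the derivative 1/(2·sqrt(λ)) grows without bound as an eigenvalue λ approaches zero.

```python
import torch

def eigenvalue_sqrt_grad(lam_value: float) -> float:
    """Gradient of sqrt(lambda) w.r.t. one eigenvalue of a diagonal PSD matrix.

    Hypothetical helper for illustration only.
    """
    lam = torch.tensor([4.0, lam_value], dtype=torch.float64, requires_grad=True)
    # For A = diag(lam), sqrtm(A) = diag(sqrt(lam)), so the diagonal is all we need.
    torch.sqrt(lam).sum().backward()
    return lam.grad[1].item()

# The gradient grows as the eigenvalue shrinks and is infinite at exactly zero.
for lam_value in (1.0, 1e-4, 1e-8, 0.0):
    print(lam_value, eigenvalue_sqrt_grad(lam_value))
```

In practice this suggests keeping test matrices strictly positive definite (for example by adding a small multiple of the identity) before asking for gradients.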

Steve Li’s test code for the above:

import torch
from torch.autograd import gradcheck
if use_li:
    k = torch.randn(1000, 1000).double()
    # Create a positive definite matrix
    pd_mat = (k.t().matmul(k)).requires_grad_()
    with torch.no_grad():
        sq = sqrtm_li(pd_mat)
    print("sq =\n",sq)
    #print("Running gradcheck...")
    #test = gradcheck(sqrtm_li, (pd_mat,))
    #print(test)
sq =
 tensor([[ 2.6360e+01, -5.2896e-01,  4.6020e-01,  ...,  4.6385e-01,
         -2.5534e-01,  3.3804e-01],
        [-5.2896e-01,  2.5773e+01, -7.2415e-01,  ..., -2.6621e-02,
          3.0918e-01, -7.8089e-02],
        [ 4.6020e-01, -7.2415e-01,  2.5863e+01,  ..., -5.2346e-01,
          1.4617e-01, -2.5943e-01],
        ...,
        [ 4.6385e-01, -2.6621e-02, -5.2346e-01,  ...,  2.6959e+01,
          3.6158e-01, -4.6653e-01],
        [-2.5534e-01,  3.0918e-01,  1.4617e-01,  ...,  3.6158e-01,
          2.6692e+01, -4.4417e-01],
        [ 3.3804e-01, -7.8089e-02, -2.5943e-01,  ..., -4.6653e-01,
         -4.4417e-01,  2.8916e+01]], dtype=torch.float64)

Subhransu Maji’s method(s)

From https://github.com/msubhransu/matrix-sqrt



sqrt_newton_schulz

 sqrt_newton_schulz (A, numIters=20, calc_error=False)

Sqrt of matrix via Newton-Schulz algorithm. Modified from https://github.com/msubhransu/matrix-sqrt/blob/cc2289a3ed7042b8dbacd53ce8a34da1f814ed2f/matrix_sqrt.py#LL72C1-L87C19

Forward pass via Newton-Schulz iterations (non-autograd version). Seems to be slightly faster and has much lower memory overhead.

… Original code didn’t preserve device, had no batch dim checking -SHH

Parameter   Type  Default  Details
A                          matrix to be sqrt-ified
numIters    int   20       numIters=7 found via experimentation
calc_error  bool  False    setting False disables Maji’s error reporting
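The following is a minimal sketch of a batched Newton-Schulz forward pass, written to illustrate the two points noted above (device preservation and a batch-dimension check). The function name `newton_schulz_sqrt_sketch` and its argument names are hypothetical. It follows the coupled iteration used in Maji's repository as I understand it, and should not be read as the exact code of `sqrt_newton_schulz`.

```python
import torch

def newton_schulz_sqrt_sketch(A: torch.Tensor, num_iters: int = 20) -> torch.Tensor:
    """Approximate square roots of a batch of SPD matrices, shape (batch, n, n)."""
    if A.dim() != 3 or A.shape[1] != A.shape[2]:
        raise ValueError(f"expected shape (batch, n, n), got {tuple(A.shape)}")
    with torch.no_grad():  # forward only, nothing is recorded for backprop
        batch, n, _ = A.shape
        # Scale by the Frobenius norm so the iteration starts inside its convergence region.
        norm_A = A.pow(2).sum(dim=(1, 2)).sqrt().view(batch, 1, 1)
        Y = A / norm_A
        # Identity is created on the same device and dtype as A.
        I = torch.eye(n, dtype=A.dtype, device=A.device).expand(batch, n, n)
        Z = I.clone()
        for _ in range(num_iters):
            T = 0.5 * (3.0 * I - Z.bmm(Y))
            Y = Y.bmm(T)  # Y tends to (A / ||A||)^(1/2)
            Z = T.bmm(Z)  # Z tends to (A / ||A||)^(-1/2)
        return Y * norm_A.sqrt()  # undo the scaling

torch.manual_seed(0)
k = torch.randn(8, 8, dtype=torch.float64)
A = k.t() @ k + 8 * torch.eye(8, dtype=torch.float64)  # well-conditioned SPD matrix
sA = newton_schulz_sqrt_sketch(A.unsqueeze(0))[0]
print("relative residual:", ((sA @ sA - A).norm() / A.norm()).item())
```

A single 2D matrix is passed with `unsqueeze(0)` and the result is indexed with `[0]`, mirroring how the batched functions are called elsewhere on this page.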


sqrt_newton_schulz_autograd

 sqrt_newton_schulz_autograd (A, numIters=20, calc_error=False)

Modified from https://people.cs.umass.edu/~smaji/projects/matrix-sqrt/

“The drawback of the autograd approach [i.e., this approach] is that a naive implementation stores all the intermediate results. Thus the memory overhead scales linearly with the number of iterations which is problematic for large matrices.”

Parameter   Type  Default  Details
A
numIters    int   20       found experimentally by SHH, comparing w/ Li’s method
calc_error  bool  False
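As a sketch of what the autograd approach looks like, the same iteration can be run without `torch.no_grad()` so that every intermediate `T`, `Y`, `Z` stays in the graph (which is the memory cost the quote describes). The function name `newton_schulz_sqrt_autograd_sketch` is hypothetical and this is not the code of `sqrt_newton_schulz_autograd`. The check uses the identity that the gradient of trace(sqrt(A)) with respect to A is 0.5·A^(-1/2).

```python
import torch

def newton_schulz_sqrt_autograd_sketch(A: torch.Tensor, num_iters: int = 20) -> torch.Tensor:
    """Differentiable Newton-Schulz square root for A of shape (batch, n, n)."""
    batch, n, _ = A.shape
    norm_A = A.pow(2).sum(dim=(1, 2)).sqrt().view(batch, 1, 1)
    Y = A / norm_A
    I = torch.eye(n, dtype=A.dtype, device=A.device).expand(batch, n, n)
    Z = I.clone()
    for _ in range(num_iters):
        # Autograd keeps each T, Y, Z alive for the backward pass,
        # so memory use grows with num_iters.
        T = 0.5 * (3.0 * I - Z.bmm(Y))
        Y = Y.bmm(T)
        Z = T.bmm(Z)
    return Y * norm_A.sqrt()

torch.manual_seed(0)
k = torch.randn(6, 6, dtype=torch.float64)
A = (k.t() @ k + 6 * torch.eye(6, dtype=torch.float64)).unsqueeze(0).requires_grad_()
sA = newton_schulz_sqrt_autograd_sketch(A)
sA.diagonal(dim1=1, dim2=2).sum().backward()   # loss = trace(sqrt(A))
expected = 0.5 * torch.linalg.inv(sA.detach())  # analytic gradient of the trace
print("max gradient mismatch:", (A.grad - expected).abs().max().item())
```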


compute_error

 compute_error (A, sA)
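The page does not describe what `compute_error` measures. One plausible definition, and as far as I recall the one used in Maji's repository, is the relative Frobenius-norm residual of the computed root, averaged over the batch. The sketch below assumes that definition; the name `relative_sqrt_error_sketch` is hypothetical.

```python
import torch

def relative_sqrt_error_sketch(A: torch.Tensor, sA: torch.Tensor) -> torch.Tensor:
    """Mean over the batch of ||A - sA @ sA||_F / ||A||_F (assumed definition)."""
    residual = A - sA.bmm(sA)
    res_norm = residual.pow(2).sum(dim=(1, 2)).sqrt()
    a_norm = A.pow(2).sum(dim=(1, 2)).sqrt()
    return (res_norm / a_norm).mean()

A = torch.diag(torch.tensor([4.0, 9.0])).unsqueeze(0)
exact = torch.diag(torch.tensor([2.0, 3.0])).unsqueeze(0)
print(relative_sqrt_error_sketch(A, exact).item())        # 0.0 for the exact root
print(relative_sqrt_error_sketch(A, 1.1 * exact).item())  # roughly 0.21 for a 10% overestimate
```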

Error tests

sa1, error = sqrt_newton_schulz_autograd( pd_mat.unsqueeze(0), numIters=20, calc_error=True ) 
print("sa1 =\n",sa1)
print("error =",error.detach().item())
sa1 =
 tensor([[[ 2.6360e+01, -5.2896e-01,  4.6021e-01,  ...,  4.6384e-01,
          -2.5535e-01,  3.3805e-01],
         [-5.2896e-01,  2.5773e+01, -7.2415e-01,  ..., -2.6632e-02,
           3.0917e-01, -7.8092e-02],
         [ 4.6021e-01, -7.2415e-01,  2.5863e+01,  ..., -5.2344e-01,
           1.4618e-01, -2.5943e-01],
         ...,
         [ 4.6384e-01, -2.6632e-02, -5.2344e-01,  ...,  2.6959e+01,
           3.6150e-01, -4.6653e-01],
         [-2.5535e-01,  3.0917e-01,  1.4618e-01,  ...,  3.6150e-01,
           2.6692e+01, -4.4418e-01],
         [ 3.3805e-01, -7.8092e-02, -2.5943e-01,  ..., -4.6653e-01,
          -4.4418e-01,  2.8916e+01]]], dtype=torch.float64,
       grad_fn=<MulBackward0>)
error = 4.759428080865442e-08
sa2, error = sqrt_newton_schulz( pd_mat.unsqueeze(0), numIters=20, calc_error=True ) 
print("sa2 =\n",sa2)
print("error =",error.detach().item())
sa2 =
 tensor([[[ 2.6360e+01, -5.2896e-01,  4.6021e-01,  ...,  4.6384e-01,
          -2.5535e-01,  3.3805e-01],
         [-5.2896e-01,  2.5773e+01, -7.2415e-01,  ..., -2.6632e-02,
           3.0917e-01, -7.8092e-02],
         [ 4.6021e-01, -7.2415e-01,  2.5863e+01,  ..., -5.2344e-01,
           1.4618e-01, -2.5943e-01],
         ...,
         [ 4.6384e-01, -2.6632e-02, -5.2344e-01,  ...,  2.6959e+01,
           3.6150e-01, -4.6653e-01],
         [-2.5535e-01,  3.0917e-01,  1.4618e-01,  ...,  3.6150e-01,
           2.6692e+01, -4.4418e-01],
         [ 3.3805e-01, -7.8092e-02, -2.5943e-01,  ..., -4.6653e-01,
          -4.4418e-01,  2.8916e+01]]], dtype=torch.float64,
       grad_fn=<MulBackward0>)
error = 4.759428080865442e-08
if use_li:
    diff = sa1 - sq
    print("diff = \n",diff)
diff = 
 tensor([[[-3.7835e-05,  3.8437e-07,  8.7208e-06,  ..., -1.2816e-05,
          -1.8744e-05,  1.3386e-05],
         [ 3.8437e-07, -1.9896e-06,  2.2977e-06,  ..., -1.1004e-05,
          -1.0735e-05, -3.3250e-06],
         [ 8.7208e-06,  2.2977e-06, -7.5329e-06,  ...,  2.2518e-05,
           1.9394e-05, -3.9655e-06],
         ...,
         [-1.2816e-05, -1.1004e-05,  2.2518e-05,  ..., -8.1416e-05,
          -7.1274e-05, -1.4799e-06],
         [-1.8744e-05, -1.0735e-05,  1.9394e-05,  ..., -7.1274e-05,
          -8.0163e-05, -1.6137e-05],
         [ 1.3386e-05, -3.3250e-06, -3.9655e-06,  ..., -1.4799e-06,
          -1.6137e-05, -2.6324e-05]]], dtype=torch.float64,
       grad_fn=<SubBackward0>)
if use_li:
    diff = sa2 - sq
    print("diff = \n",diff)
diff = 
 tensor([[[-3.7835e-05,  3.8437e-07,  8.7208e-06,  ..., -1.2816e-05,
          -1.8744e-05,  1.3386e-05],
         [ 3.8437e-07, -1.9896e-06,  2.2977e-06,  ..., -1.1004e-05,
          -1.0735e-05, -3.3250e-06],
         [ 8.7208e-06,  2.2977e-06, -7.5329e-06,  ...,  2.2518e-05,
           1.9394e-05, -3.9655e-06],
         ...,
         [-1.2816e-05, -1.1004e-05,  2.2518e-05,  ..., -8.1416e-05,
          -7.1274e-05, -1.4799e-06],
         [-1.8744e-05, -1.0735e-05,  1.9394e-05,  ..., -7.1274e-05,
          -8.0163e-05, -1.6137e-05],
         [ 1.3386e-05, -3.3250e-06, -3.9655e-06,  ..., -1.4799e-06,
          -1.6137e-05, -2.6324e-05]]], dtype=torch.float64,
       grad_fn=<SubBackward0>)

Speed & device tests

from aeiou.core import get_device
device = get_device()
print('device = ',device)
n,m = 1000, 1000
with torch.no_grad(): 
    k = torch.randn(n, m, device=device)
    pd_mat2 = (k.t().matmul(k)) # Create a positive definite matrix, no grad
device =  cuda
# %%timeit
if use_li:
    sq2 = sqrtm_li(pd_mat2)

Result of %%timeit:

1.12 s ± 191 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

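# sq2 only exists if the use_li branch above ran as a regular cell
# (names assigned under %%timeit are not kept), so guard the print.
if use_li:
    print(sq2)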
# %%timeit
sq3 = sqrt_newton_schulz(pd_mat2.unsqueeze(0), numIters=20)[0]

Result of %%timeit:

8.8 ms ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
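For reproducing timings like the ones above outside of IPython's `%%timeit`, a small helper along the following lines may be useful. CUDA kernels are launched asynchronously, so the sketch synchronizes the device before reading the clock. The name `time_fn` is hypothetical, and `torch.linalg.eigh` is used only as a stand-in workload, not as one of the methods timed on this page.

```python
import time
import torch

def time_fn(fn, *args, repeats: int = 5) -> float:
    """Average wall-clock seconds per call, synchronizing CUDA if it is in use."""
    use_cuda = any(isinstance(a, torch.Tensor) and a.is_cuda for a in args)
    fn(*args)  # warm-up call so one-time setup costs are not measured
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (time.perf_counter() - start) / repeats

device = "cuda" if torch.cuda.is_available() else "cpu"
k = torch.randn(200, 200, device=device)
pd = k.t() @ k  # positive semi-definite test matrix
seconds = time_fn(torch.linalg.eigh, pd)
print(f"{seconds * 1e3:.3f} ms per call on {device}")
```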

Wrapper around our method of choice:

TL;DR: we’ll use Maji’s Newton-Schulz method. Newton-Schulz is an approximate, iterative method rather than an exact matrix sqrt. However, with 7 iterations the error is below 1e-5, which is (presumably significantly) lower than other errors in the problem.



sqrtm

 sqrtm (A, method='maji', numIters=20)

Wrapper function for the matrix sqrt algorithm of choice. It also turns off all gradients.
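A wrapper of this kind might look roughly like the sketch below: it runs under `torch.no_grad()`, adds a batch dimension for 2D inputs, dispatches on `method`, and removes the batch dimension again. All names here (`sqrtm_sketch`, `_newton_schulz`, `_eigh_sqrt`) are hypothetical. The `"eigh"` branch is an eigendecomposition-based reference that I have swapped in for illustration; it is not Li's SciPy-based method.

```python
import torch

def _newton_schulz(A: torch.Tensor, num_iters: int) -> torch.Tensor:
    """Batched Newton-Schulz iteration for A of shape (batch, n, n)."""
    batch, n, _ = A.shape
    norm_A = A.pow(2).sum(dim=(1, 2)).sqrt().view(batch, 1, 1)
    Y = A / norm_A
    I = torch.eye(n, dtype=A.dtype, device=A.device).expand(batch, n, n)
    Z = I.clone()
    for _ in range(num_iters):
        T = 0.5 * (3.0 * I - Z.bmm(Y))
        Y, Z = Y.bmm(T), T.bmm(Z)
    return Y * norm_A.sqrt()

def _eigh_sqrt(A: torch.Tensor) -> torch.Tensor:
    """Reference root V diag(sqrt(lambda)) V^T (stand-in, not Li's method)."""
    evals, evecs = torch.linalg.eigh(A)
    scaled = evecs * evals.clamp(min=0).sqrt().unsqueeze(-2)  # scale each eigenvector column
    return scaled @ evecs.transpose(-1, -2)

def sqrtm_sketch(A: torch.Tensor, method: str = "maji", numIters: int = 20) -> torch.Tensor:
    """Gradient-free matrix square root with a simple method switch."""
    with torch.no_grad():
        batched = A.dim() == 3
        A3 = A if batched else A.unsqueeze(0)
        if method == "maji":
            out = _newton_schulz(A3, numIters)
        elif method == "eigh":
            out = _eigh_sqrt(A3)
        else:
            raise ValueError(f"unknown method: {method!r}")
        return out if batched else out[0]

torch.manual_seed(0)
k = torch.randn(8, 8, dtype=torch.float64)
A = (k.t() @ k + 8 * torch.eye(8, dtype=torch.float64)).requires_grad_()
diff = sqrtm_sketch(A, method="maji") - sqrtm_sketch(A, method="eigh")
print("max abs difference:", diff.abs().max().item())
```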

sqrtm(pd_mat2, method='maji') - sqrtm(pd_mat2, method='li')
tensor([[ 1.9073e-06,  1.5616e-05,  9.0003e-06,  ..., -8.9407e-07,
         -3.7104e-05,  6.7353e-06],
        [ 1.9193e-05, -6.1035e-05, -4.7386e-06,  ...,  7.8678e-06,
          2.9802e-05, -5.8711e-06],
        [ 7.5698e-06, -5.4836e-06, -3.8147e-05,  ...,  2.0236e-05,
          1.6876e-06, -2.9024e-05],
        ...,
        [-3.3975e-06,  3.6955e-06,  2.2471e-05,  ..., -1.1444e-05,
          3.6955e-06,  1.2971e-05],
        [-3.6355e-05,  2.5883e-05,  7.5437e-06,  ...,  9.4771e-06,
         -6.8665e-05,  6.7949e-06],
        [ 4.2319e-06, -1.3113e-05, -2.6686e-05,  ...,  9.9763e-06,
          1.1921e-05,  5.7220e-06]], device='cuda:0')