
Methods for computing sqrt of a matrix

Steve Li’s method



 MatrixSquareRoot_li (*args, **kwargs)

From https://github.com/steveli/pytorch-sqrtm/blob/master/sqrtm.py, which unfortunately cannot be installed as a package. The LICENSE is included below. Computes the square root of a positive definite matrix.

NOTE: matrix square root is not differentiable for matrices with zero eigenvalues.
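A short sketch of why the note above holds, assuming a symmetric positive semi-definite A with eigendecomposition A = Q Λ Qᵀ and S = sqrt(A). The symbols Q, Λ, λ_i and the tilde notation are introduced here for illustration only. Differentiating S² = A gives a Sylvester equation for dS, and solving it in the eigenbasis shows a denominator that vanishes when two eigenvalues are both zero:

```latex
S^2 = A
\;\Longrightarrow\;
S\,dS + dS\,S = dA

% In the eigenbasis, with \tilde{X} = Q^\top X Q and S = Q \Lambda^{1/2} Q^\top:
\widetilde{dS}_{ij} = \frac{\widetilde{dA}_{ij}}{\sqrt{\lambda_i} + \sqrt{\lambda_j}}
```

When λ_i = λ_j = 0 the denominator is zero, so the gradient is undefined; eigenvalues that are merely close to zero make it very large.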

Steve Li’s test code for the above:

import torch
from torch.autograd import gradcheck
if use_li:
    k = torch.randn(1000, 1000).double()
    # Create a positive definite matrix
    pd_mat = (k.t().matmul(k)).requires_grad_()
    with torch.no_grad():
        sq = sqrtm_li(pd_mat)
    print("sq =\n",sq)
    #print("Running gradcheck...")
    #test = gradcheck(sqrtm_li, (pd_mat,))
sq =
 tensor([[ 2.6360e+01, -5.2896e-01,  4.6020e-01,  ...,  4.6385e-01,
         -2.5534e-01,  3.3804e-01],
        [-5.2896e-01,  2.5773e+01, -7.2415e-01,  ..., -2.6621e-02,
          3.0918e-01, -7.8089e-02],
        [ 4.6020e-01, -7.2415e-01,  2.5863e+01,  ..., -5.2346e-01,
          1.4617e-01, -2.5943e-01],
        [ 4.6385e-01, -2.6621e-02, -5.2346e-01,  ...,  2.6959e+01,
          3.6158e-01, -4.6653e-01],
        [-2.5534e-01,  3.0918e-01,  1.4617e-01,  ...,  3.6158e-01,
          2.6692e+01, -4.4417e-01],
        [ 3.3804e-01, -7.8089e-02, -2.5943e-01,  ..., -4.6653e-01,
         -4.4417e-01,  2.8916e+01]], dtype=torch.float64)

Subhransu Maji’s method(s)

From https://github.com/msubhransu/matrix-sqrt



 sqrt_newton_schulz (A, numIters=20, calc_error=False)

Square root of a matrix via the Newton-Schulz algorithm. Modified from https://github.com/msubhransu/matrix-sqrt/blob/cc2289a3ed7042b8dbacd53ce8a34da1f814ed2f/matrix_sqrt.py#LL72C1-L87C19. This is the forward pass via Newton-Schulz iterations (non-autograd version); it seems to be slightly faster and has much lower memory overhead.

… Original code didn’t preserve device, had no batch dim checking -SHH

| Parameter | Type | Default | Details |
|---|---|---|---|
| A | | | matrix to be sqrt-ified |
| numIters | int | 20 | numIters=7 found via experimentation |
| calc_error | bool | False | setting False disables Maji’s error reporting |
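For orientation, here is a minimal sketch of the coupled Newton-Schulz iteration that this kind of routine is built on. It is not the library's `sqrt_newton_schulz`: the function name `newton_schulz_sqrt_sketch` and the parameter `num_iters` are made up for illustration, and the error reporting is left out. It assumes a batch of symmetric positive definite matrices of shape (B, n, n).

```python
import torch


def newton_schulz_sqrt_sketch(A: torch.Tensor, num_iters: int = 20) -> torch.Tensor:
    """Approximate sqrt of a batch of SPD matrices with shape (B, n, n).

    Hypothetical illustration only; names and defaults are assumptions.
    """
    if A.dim() != 3 or A.shape[-1] != A.shape[-2]:
        raise ValueError("expected a batch of square matrices, shape (B, n, n)")
    n = A.shape[-1]
    # Scale by the Frobenius norm so the eigenvalues of Y lie in (0, 1],
    # which keeps the iteration inside its region of convergence.
    norm_A = torch.linalg.matrix_norm(A, ord="fro", keepdim=True)  # (B, 1, 1)
    Y = A / norm_A
    # Build the identity on the same device and dtype as the input.
    I = torch.eye(n, dtype=A.dtype, device=A.device)
    Z = I
    for _ in range(num_iters):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y = Y @ T  # Y converges to sqrt(A / ||A||)
        Z = T @ Z  # Z converges to the inverse of that square root
    # Undo the initial scaling.
    return Y * norm_A.sqrt()


# Usage: a small, well-conditioned SPD matrix with a batch dimension of 1.
A = torch.tensor([[[4.0, 1.0], [1.0, 3.0]]], dtype=torch.float64)
S = newton_schulz_sqrt_sketch(A, num_iters=20)
residual = torch.linalg.matrix_norm(S @ S - A) / torch.linalg.matrix_norm(A)
print("relative residual:", residual.item())
```

Creating the identity with `dtype=A.dtype, device=A.device` is one way to keep the computation on the input's device, which is the device-preservation concern mentioned above.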



 sqrt_newton_schulz_autograd (A, numIters=20, calc_error=False)

Modified from https://people.cs.umass.edu/~smaji/projects/matrix-sqrt/. “The drawback of the autograd approach [i.e., this approach] is that a naive implementation stores all the intermediate results. Thus the memory overhead scales linearly with the number of iterations which is problematic for large matrices.”

| Parameter | Type | Default | Details |
|---|---|---|---|
| numIters | int | 20 | found experimentally by SHH, comparing w/ Li’s method |
| calc_error | bool | False | |



 compute_error (A, sA)

Error tests

sa1, error = sqrt_newton_schulz_autograd( pd_mat.unsqueeze(0), numIters=20, calc_error=True ) 
print("sa1 =\n",sa1)
print("error =",error.detach().item())
sa1 =
 tensor([[[ 2.6360e+01, -5.2896e-01,  4.6021e-01,  ...,  4.6384e-01,
          -2.5535e-01,  3.3805e-01],
         [-5.2896e-01,  2.5773e+01, -7.2415e-01,  ..., -2.6632e-02,
           3.0917e-01, -7.8092e-02],
         [ 4.6021e-01, -7.2415e-01,  2.5863e+01,  ..., -5.2344e-01,
           1.4618e-01, -2.5943e-01],
         [ 4.6384e-01, -2.6632e-02, -5.2344e-01,  ...,  2.6959e+01,
           3.6150e-01, -4.6653e-01],
         [-2.5535e-01,  3.0917e-01,  1.4618e-01,  ...,  3.6150e-01,
           2.6692e+01, -4.4418e-01],
         [ 3.3805e-01, -7.8092e-02, -2.5943e-01,  ..., -4.6653e-01,
          -4.4418e-01,  2.8916e+01]]], dtype=torch.float64,
error = 4.759428080865442e-08
sa2, error = sqrt_newton_schulz( pd_mat.unsqueeze(0), numIters=20, calc_error=True ) 
print("sa2 =\n",sa2)
print("error =",error.detach().item())
sa2 =
 tensor([[[ 2.6360e+01, -5.2896e-01,  4.6021e-01,  ...,  4.6384e-01,
          -2.5535e-01,  3.3805e-01],
         [-5.2896e-01,  2.5773e+01, -7.2415e-01,  ..., -2.6632e-02,
           3.0917e-01, -7.8092e-02],
         [ 4.6021e-01, -7.2415e-01,  2.5863e+01,  ..., -5.2344e-01,
           1.4618e-01, -2.5943e-01],
         [ 4.6384e-01, -2.6632e-02, -5.2344e-01,  ...,  2.6959e+01,
           3.6150e-01, -4.6653e-01],
         [-2.5535e-01,  3.0917e-01,  1.4618e-01,  ...,  3.6150e-01,
           2.6692e+01, -4.4418e-01],
         [ 3.3805e-01, -7.8092e-02, -2.5943e-01,  ..., -4.6653e-01,
          -4.4418e-01,  2.8916e+01]]], dtype=torch.float64,
error = 4.759428080865442e-08
if use_li:
    diff = sa1 - sq
    print("diff = \n",diff)
diff = 
 tensor([[[-3.7835e-05,  3.8437e-07,  8.7208e-06,  ..., -1.2816e-05,
          -1.8744e-05,  1.3386e-05],
         [ 3.8437e-07, -1.9896e-06,  2.2977e-06,  ..., -1.1004e-05,
          -1.0735e-05, -3.3250e-06],
         [ 8.7208e-06,  2.2977e-06, -7.5329e-06,  ...,  2.2518e-05,
           1.9394e-05, -3.9655e-06],
         [-1.2816e-05, -1.1004e-05,  2.2518e-05,  ..., -8.1416e-05,
          -7.1274e-05, -1.4799e-06],
         [-1.8744e-05, -1.0735e-05,  1.9394e-05,  ..., -7.1274e-05,
          -8.0163e-05, -1.6137e-05],
         [ 1.3386e-05, -3.3250e-06, -3.9655e-06,  ..., -1.4799e-06,
          -1.6137e-05, -2.6324e-05]]], dtype=torch.float64,
if use_li:
    diff = sa2 - sq
    print("diff = \n",diff)
diff = 
 tensor([[[-3.7835e-05,  3.8437e-07,  8.7208e-06,  ..., -1.2816e-05,
          -1.8744e-05,  1.3386e-05],
         [ 3.8437e-07, -1.9896e-06,  2.2977e-06,  ..., -1.1004e-05,
          -1.0735e-05, -3.3250e-06],
         [ 8.7208e-06,  2.2977e-06, -7.5329e-06,  ...,  2.2518e-05,
           1.9394e-05, -3.9655e-06],
         [-1.2816e-05, -1.1004e-05,  2.2518e-05,  ..., -8.1416e-05,
          -7.1274e-05, -1.4799e-06],
         [-1.8744e-05, -1.0735e-05,  1.9394e-05,  ..., -7.1274e-05,
          -8.0163e-05, -1.6137e-05],
         [ 1.3386e-05, -3.3250e-06, -3.9655e-06,  ..., -1.4799e-06,
          -1.6137e-05, -2.6324e-05]]], dtype=torch.float64,

Speed & device tests

from aeiou.core import get_device
device = get_device()
print('device = ',device)
n,m = 1000, 1000
with torch.no_grad(): 
    k = torch.randn(n, m, device=device)
    pd_mat2 = (k.t().matmul(k)) # Create a positive definite matrix, no grad
device =  cuda
# %%timeit
if use_li:
    sq2 = sqrtm_li(pd_mat2)

Result of %%timeit:

1.12 s ± 191 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

NameError: name 'sq2' is not defined
# %%timeit
sq3 = sqrt_newton_schulz(pd_mat2.unsqueeze(0), numIters=20)[0]

Result of %%timeit:

8.8 ms ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Wrapper around our method of choice:

TL;DR: we’ll use Maji’s Newton-Schulz method. Newton-Schulz is an approximate iterative method rather than an exact matrix square root; however, with 7 iterations the error is below 1e-5, which is (presumably significantly) lower than other errors in the problem.



 sqrtm (A, method='maji', numIters=20)

Wrapper function for the matrix sqrt algorithm of choice. It also turns off all gradients.
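A rough sketch of what such a wrapper can look like, not the library's actual `sqrtm`. The names `sqrtm_sketch`, `sqrtm_eigh` and `_BACKENDS` are hypothetical, and an eigendecomposition-based square root stands in for the `'maji'` and `'li'` backends so that the example is self-contained. It shows three things: choosing a method by name, running under `torch.no_grad()`, and adding then removing a batch dimension for unbatched input.

```python
import torch


def sqrtm_eigh(A: torch.Tensor) -> torch.Tensor:
    """Stand-in backend: sqrt of symmetric PSD matrices via eigendecomposition."""
    evals, evecs = torch.linalg.eigh(A)
    # Clamp tiny negative eigenvalues caused by round-off before the root.
    root = evals.clamp_min(0).sqrt()
    # V diag(sqrt(lambda)) V^T, written with broadcasting over the batch.
    return (evecs * root.unsqueeze(-2)) @ evecs.transpose(-2, -1)


# Hypothetical registry; the wrapper documented here selects 'maji' or 'li'.
_BACKENDS = {"eigh": sqrtm_eigh}


def sqrtm_sketch(A: torch.Tensor, method: str = "eigh") -> torch.Tensor:
    if method not in _BACKENDS:
        raise ValueError(f"unknown method {method!r}; choose from {sorted(_BACKENDS)}")
    unbatched = A.dim() == 2
    with torch.no_grad():  # no graph is recorded, so the result carries no gradient
        batched = A.unsqueeze(0) if unbatched else A
        out = _BACKENDS[method](batched)
    return out[0] if unbatched else out


# Usage: the input requires grad, the output does not.
A = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64, requires_grad=True)
S = sqrtm_sketch(A)
print(S.requires_grad)
```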

sqrtm(pd_mat2, method='maji') - sqrtm(pd_mat2, method='li')
tensor([[ 1.9073e-06,  1.5616e-05,  9.0003e-06,  ..., -8.9407e-07,
         -3.7104e-05,  6.7353e-06],
        [ 1.9193e-05, -6.1035e-05, -4.7386e-06,  ...,  7.8678e-06,
          2.9802e-05, -5.8711e-06],
        [ 7.5698e-06, -5.4836e-06, -3.8147e-05,  ...,  2.0236e-05,
          1.6876e-06, -2.9024e-05],
        [-3.3975e-06,  3.6955e-06,  2.2471e-05,  ..., -1.1444e-05,
          3.6955e-06,  1.2971e-05],
        [-3.6355e-05,  2.5883e-05,  7.5437e-06,  ...,  9.4771e-06,
         -6.8665e-05,  6.7949e-06],
        [ 4.2319e-06, -1.3113e-05, -2.6686e-05,  ...,  9.9763e-06,
          1.1921e-05,  5.7220e-06]], device='cuda:0')