In [1]:
import ctf
from ctf import random
# Communicator handle; its rank() is queried later so that only rank 0 prints progress.
glob_comm = ctf.comm()

Construct 1D Poisson stiffness matrix $A$.

In [2]:
n = 3  # number of grid points per spatial dimension

# Tridiagonal second-difference stencil [1, -2, 1], assembled from sparse diagonals.
main_diag = (-2.) * ctf.eye(n, n, sp=True)
super_diag = ctf.eye(n, n, 1, sp=True)
sub_diag = ctf.eye(n, n, -1, sp=True)
A = main_diag + super_diag + sub_diag
A
Out[2]:
array([[-2.,  1.,  0.],
       [ 1., -2.,  1.],
       [ 0.,  1., -2.]])

Construct 3D Poisson stiffness matrix $T = A \otimes I \otimes I + I \otimes A \otimes I + I \otimes I \otimes A$ as an order 6 tensor.

In [3]:
I = ctf.eye(n, n, sp=True)  # sparse identity matrix
T = ctf.tensor((n,) * 6, sp=True)  # sparse order-6 tensor, initially zero

# `<<` accumulates into its left-hand side, so add one Kronecker term per statement.
# Index pairs (a,b), (i,j), (x,y) are the (row, column) indices of the three factors.
T.i("aixbjy") << A.i("ab")*I.i("ij")*I.i("xy")
T.i("aixbjy") << I.i("ab")*A.i("ij")*I.i("xy")
T.i("aixbjy") << I.i("ab")*I.i("ij")*A.i("xy")

The 3D Poisson stiffness matrix is full rank.

In [4]:
# Flatten modes (a,i,x) into rows and (b,j,y) into columns to view T as a square matrix.
n_dof = n**3
U, S, V = ctf.svd(T.reshape((n_dof, n_dof)))
S
Out[4]:
array([10.24264069,  8.82842712,  8.82842712,  8.82842712,  7.41421356,
        7.41421356,  7.41421356,  7.41421356,  7.41421356,  7.41421356,
        6.        ,  6.        ,  6.        ,  6.        ,  6.        ,
        6.        ,  6.        ,  4.58578644,  4.58578644,  4.58578644,
        4.58578644,  4.58578644,  4.58578644,  3.17157288,  3.17157288,
        3.17157288,  1.75735931])

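However, if we permute the tensor modes so that the row and column index of each Kronecker factor are adjacent (ordering $abijxy$ instead of $aixbjy$), the unfolding that separates the first factor's indices $(a,b)$ from the rest has rank 2.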

In [5]:
T2 = ctf.tensor((n,) * 6, sp=True)
T2.i("abijxy") << T.i("aixbjy")  # permute modes: group each factor's (row, col) pair

# Unfold with (a,b) as rows and (i,j,x,y) as columns, then take a rank-2 truncated SVD.
T2_first = T2.reshape((n**2, n**4))
U, S, V = ctf.svd(T2_first, 2)
print(ctf.vecnorm(T2_first - U @ ctf.diag(S, sp=True) @ V))  # norm of the truncation error
6.717730079336022e-14

In fact, there are two low-rank matrix unfoldings.

In [6]:
# Unfold with (a,b,i,j) as rows and (x,y) as columns, then take a rank-2 truncated SVD.
T2_last = T2.reshape((n**4, n**2))
U, S, V = ctf.svd(T2_last, 2)
print(ctf.vecnorm(T2_last - U @ ctf.diag(S, sp=True) @ V))  # norm of the truncation error
7.860569145580134e-14

We can construct a tensor train factorization to exploit both unfoldings. The tensor train ranks are $2\times 2$.

In [7]:
tt_rank = 2  # truncation rank used for both unfoldings

# First split: separate modes (a,b) from (i,j,x,y).
U1, S1, V1 = ctf.svd(T2.reshape((n*n, n*n*n*n)), tt_rank)

# Second split: fold the first bond index together with (i,j) and separate from (x,y).
remainder = ctf.diag(S1, sp=True) @ V1
U2, S2, V2 = ctf.svd(remainder.reshape((tt_rank*n*n, n*n)), tt_rank)
V2 = ctf.diag(S2, sp=True) @ V2  # absorb the singular values into the last core

# Reshape the matrix factors into the three tensor-train cores.
W1 = U1.reshape((n, n, tt_rank))
W2 = U2.reshape((tt_rank, n, n, tt_rank))
W3 = V2.reshape((tt_rank, n, n))

The tensor train factorization requires $O(n^2)$ storage for this tensor, which is $n\times n\times n\times n\times n\times n$ and has $O(n^3)$ nonzeros.

In [8]:
# Residual between T and the contraction of the three tensor-train cores.
E = ctf.tensor((n,) * 6)
tt_product = W1.i("abu")*W2.i("uijv")*W3.i("vxy")
E.i("aixbjy") << T.i("aixbjy") - tt_product
ctf.vecnorm(E)
Out[8]:
4.549873484854002e-14

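The CP decomposition of this tensor provides further compression. A rank-2 CP approximation can be made arbitrarily accurate, but it is never exact: a tensor of the form $A \otimes I \otimes I + I \otimes A \otimes I + I \otimes I \otimes A$ has CP rank 3 and border rank 2. Alternating least squares with rank 2 therefore converges slowly, and the error printed below keeps decreasing without reaching zero.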

In [9]:
ctf.random.seed(42)  # seed the CTF random generator before drawing the starting factors

cp_rank = 2  # number of rank-one terms in the CP model

# One factor per index pair of T: Z1 <-> (a,b), Z2 <-> (i,j), Z3 <-> (x,y).
Z1 = ctf.random.random((n, n, cp_rank))
Z2 = ctf.random.random((n, n, cp_rank))
Z3 = ctf.random.random((n, n, cp_rank))
lmbda = ctf.random.random((cp_rank,))  # weight of each rank-one term

niter = 0  # ALS sweep counter

def normalize(Z):
    """Scale every component of an order-3 factor to unit norm, in place.

    For each value of the last index u, the slice Z[:, :, u] is divided by the
    square root of the sum of its squared entries. Z itself is overwritten.

    Args:
        Z: order-3 ctf tensor indexed as "pqu"; the size of its last mode is
           the number of components.

    Returns:
        A ctf vector with one entry per component holding the norm that the
        component had before scaling.
    """
    num_components = Z.shape[2]  # read from Z instead of hard-coding 2
    sq_norms = ctf.tensor(num_components)
    sq_norms.i("u") << Z.i("pqu")*Z.i("pqu")
    inv_norms = 1./sq_norms**.5
    Z_unscaled = ctf.tensor(copy=Z)
    Z.set_zero()  # `<<` accumulates, so clear Z before writing the scaled values
    Z.i("pqu") << Z_unscaled.i("pqu")*inv_norms.i("u")
    return 1./inv_norms

# Start from unit-norm factors; the norms returned by normalize() are discarded here.
for factor in (Z1, Z2, Z3):
    normalize(factor)

# Residual of the initial (random) CP model against T.
E = ctf.tensor((n,) * 6)
E.i("aixbjy") << T.i("aixbjy") - lmbda.i("u")*Z1.i("abu")*Z2.i("iju")*Z3.i("xyu")
err_norm = ctf.vecnorm(E)

def _als_update(Z_out, out_idx, Z_a, a_idx, Z_b, b_idx):
    """Overwrite Z_out with the least-squares fit of T given the other two factors.

    Forms the matrix M whose columns are the products of matching components of
    Z_a and Z_b, computes its SVD, and applies the resulting pseudo-inverse to T.
    The component weights are absorbed into Z_out; call normalize() afterwards.

    Args:
        Z_out: factor to overwrite, indexed as out_idx + "u".
        out_idx: two-letter index pair of T that belongs to Z_out (e.g. "ab").
        Z_a, Z_b: the two factors held fixed.
        a_idx, b_idx: two-letter index pairs of T that belong to Z_a and Z_b.
    """
    rank = Z_out.shape[2]
    fixed_idx = a_idx + b_idx
    M = ctf.tensor((n, n, n, n, rank))
    M.i(fixed_idx + "u") << Z_a.i(a_idx + "u")*Z_b.i(b_idx + "u")
    [U, S, V] = ctf.svd(M.reshape((n*n*n*n, rank)), rank)
    S_inv = 1./S
    Z_out.set_zero()  # `<<` accumulates, so clear the factor before writing the solution
    Z_out.i(out_idx + "u") << V.i("vu")*S_inv.i("v")*U.reshape((n, n, n, n, rank)).i(fixed_idx + "v")*T.i("aixbjy")


# Alternating least squares: stop when the residual is small or after 100 sweeps.
while (err_norm > 1.e-6 and niter < 100):
    # Report progress every 10 sweeps, from rank 0 only.
    if niter % 10 == 0 and glob_comm.rank() == 0:
        print(err_norm)

    _als_update(Z1, "ab", Z2, "ij", Z3, "xy")
    normalize(Z1)

    _als_update(Z2, "ij", Z1, "ab", Z3, "xy")
    normalize(Z2)

    _als_update(Z3, "xy", Z1, "ab", Z2, "ij")
    # Z1 and Z2 are unit norm, so the norms of Z3's components are the term weights.
    lmbda = normalize(Z3)

    E.set_zero()
    E.i("aixbjy") << T.i("aixbjy") - lmbda.i("u")*Z1.i("abu")*Z2.i("iju")*Z3.i("xyu")
    err_norm = ctf.vecnorm(E)

    niter += 1

# Recompute the final residual. `<<` accumulates into E, which already holds the
# residual from the last sweep, so E must be cleared first or it ends up doubled.
E.set_zero()
E.i("aixbjy") << T.i("aixbjy") - lmbda.i("u")*Z1.i("abu")*Z2.i("iju")*Z3.i("xyu")
err_norm = ctf.vecnorm(E)
if glob_comm.rank() == 0:
    print(err_norm)  # the loop only prints every 10th sweep, so report the final error
32.92675598008914
1.1438230719805407
1.1324005982389682
1.1175385983710975
1.0820036880901653
0.7156036446950024
0.05208439321884081
0.051909726375298185
0.05173692237121425
0.05156589242554202