Cholesky decomposition and Cholesky inversion

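For a symmetric positive-definite matrix \(\mathbf{H}\), there is a factorization \(\mathbf{H}=\mathbf{XX}^T\), where \(\mathbf{X}\) is lower triangular. Requiring the diagonal entries of \(\mathbf{X}\) to be positive makes the factor unique. This is the Cholesky decomposition, and the corresponding PyTorch function is torch.linalg.cholesky.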
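To make the decomposition concrete, the factor can be computed row by row (the Cholesky–Banachiewicz ordering): \(X_{jj}=\sqrt{H_{jj}-\sum_{k<j}X_{jk}^2}\) and \(X_{ij}=\big(H_{ij}-\sum_{k<j}X_{ik}X_{jk}\big)/X_{jj}\) for \(i>j\). Below is a minimal pure-Python sketch of that recurrence. The helper name cholesky_lower is hypothetical, and this is only an illustration of the math, not how torch.linalg.cholesky is implemented internally. The input is the rounded printout of the matrix a from the session below, so the result matches u there only to roughly four decimal places.

```python
import math

def cholesky_lower(h):
    """Hypothetical helper: return lower-triangular x with x @ x^T == h.

    Cholesky-Banachiewicz ordering: fill row i from left to right.
    Raises ValueError if h is not (numerically) positive definite.
    """
    n = len(h)
    x = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            # Dot product of the already-computed parts of rows i and j.
            s = sum(x[i][k] * x[j][k] for k in range(j))
            if i == j:
                d = h[i][i] - s
                if d <= 0.0:
                    raise ValueError("matrix is not positive definite")
                x[i][j] = math.sqrt(d)
            else:
                x[i][j] = (h[i][j] - s) / x[j][j]
    return x

# Rounded values of the matrix a printed in the session below.
a = [[1.4079, -0.1658, -0.2116],
     [-0.1658, 3.1347, 0.9066],
     [-0.2116, 0.9066, 0.4598]]
for row in cholesky_lower(a):
    print([round(v, 4) for v in row])
```

The diagonal square root is where a non-positive-definite input shows up: the quantity under the root becomes zero or negative, which is why the sketch raises instead of returning NaN.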

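Given \(\mathbf{X}\), the inverse \(\mathbf{H}^{-1}\) can be computed without inverting \(\mathbf{H}\) directly. The corresponding PyTorch function is torch.cholesky_inverse, which takes the Cholesky factor as its input (lower triangular by default; pass upper=True for an upper-triangular factor).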
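Why the factor is enough: from \(\mathbf{H}=\mathbf{XX}^T\) we get \(\mathbf{H}^{-1}=(\mathbf{X}^T)^{-1}\mathbf{X}^{-1}=(\mathbf{X}^{-1})^T\mathbf{X}^{-1}\), and a triangular matrix is easy to invert by forward substitution. The pure-Python sketch below follows exactly that identity. The helper names lower_triangular_inverse and inverse_from_cholesky are hypothetical, and this is not a description of what torch.cholesky_inverse does internally. The factor u is the rounded printout from the session below, so the result agrees with the printed inverse only to about two or three decimal places.

```python
def lower_triangular_inverse(x):
    """Hypothetical helper: invert lower-triangular x by forward substitution."""
    n = len(x)
    inv = [[0.0] * n for _ in range(n)]
    for col in range(n):
        # Solve x @ y = e_col; y becomes column `col` of the inverse.
        for i in range(n):
            rhs = 1.0 if i == col else 0.0
            s = sum(x[i][k] * inv[k][col] for k in range(i))
            inv[i][col] = (rhs - s) / x[i][i]
    return inv

def inverse_from_cholesky(x):
    """Hypothetical helper: given h = x @ x^T, return h^{-1} = (x^{-1})^T @ x^{-1}."""
    n = len(x)
    y = lower_triangular_inverse(x)
    # (y^T @ y)[i][j] = sum over k of y[k][i] * y[k][j]
    return [[sum(y[k][i] * y[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

# Rounded lower factor u printed in the session below.
u = [[1.1865, 0.0, 0.0],
     [-0.1398, 1.7650, 0.0],
     [-0.1783, 0.4995, 0.4224]]
for row in inverse_from_cholesky(u):
    print([round(v, 4) for v in row])
```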

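To compute a Cholesky decomposition, here is the example from the official documentation. Multiplying a random matrix by its transpose gives a symmetric positive semi-definite matrix, and adding 1e-05 times the identity makes it strictly positive definite: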

In [48]: a = torch.randn(3, 3)

In [49]: a = torch.mm(a, a.t()) + 1e-05 * torch.eye(3)

In [50]: a
Out[50]:
tensor([[ 1.4079, -0.1658, -0.2116],
        [-0.1658,  3.1347,  0.9066],
        [-0.2116,  0.9066,  0.4598]])

In [51]: u = torch.linalg.cholesky(a)

In [52]: u
Out[52]: 
tensor([[ 1.1865,  0.0000,  0.0000],
        [-0.1398,  1.7650,  0.0000],
        [-0.1783,  0.4995,  0.4224]])

In [53]: torch.mm(u, u.t())
Out[53]: 
tensor([[ 1.4079, -0.1658, -0.2116],
        [-0.1658,  3.1347,  0.9066],
        [-0.2116,  0.9066,  0.4598]])

In [54]: v = torch.linalg.cholesky(a, upper=True)

In [55]: v
Out[55]: 
tensor([[ 1.1865, -0.1398, -0.1783],
        [ 0.0000,  1.7650,  0.4995],
        [ 0.0000,  0.0000,  0.4224]])

In [56]: torch.mm(v.t(), v)
Out[56]: 
tensor([[ 1.4079, -0.1658, -0.2116],
        [-0.1658,  3.1347,  0.9066],
        [-0.2116,  0.9066,  0.4598]])


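As shown, a is a symmetric positive-definite matrix, and the decomposition yields either the lower-triangular factor \(u\) or the upper-triangular factor \(v\), such that \(uu^T=a\) or \(v^Tv=a\).
Comparing the two outputs also shows that \(v^T=u\).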
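The relation \(v^T=u\) is not a coincidence of this example. Assuming both factors have positive diagonal entries (the usual convention for Cholesky factors), a short uniqueness argument gives it, writing \(w=v^T\):

```latex
\begin{aligned}
&w := v^T \ (\text{lower triangular, positive diagonal}), \qquad u\,u^T = a = v^T v = w\,w^T \\
&\Rightarrow\ D := w^{-1}u = w^T\,(u^T)^{-1}
  \quad (\text{lower triangular} = \text{upper triangular} \Rightarrow D \text{ is diagonal}) \\
&\Rightarrow\ u = wD, \qquad w\,D^2\,w^T = w\,w^T \ \Rightarrow\ D^2 = I \\
&D_{ii} = u_{ii}/w_{ii} > 0 \ \Rightarrow\ D = I \ \Rightarrow\ u = w = v^T
\end{aligned}
```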

Next, compute the inverse:

In [77]: a_inverse = torch.cholesky_inverse(u)

In [78]: a_inverse
Out[78]: 
tensor([[ 0.7914, -0.1477,  0.6553],
        [-0.1477,  0.7699, -1.5859],
        [ 0.6553, -1.5859,  5.6035]])

In [79]: a.inverse()
Out[79]: 
tensor([[ 0.7914, -0.1477,  0.6553],
        [-0.1477,  0.7699, -1.5859],
        [ 0.6553, -1.5859,  5.6035]])
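The session passes the lower factor u. torch.cholesky_inverse can also take the upper factor v, but it has to be told which triangle it is getting via upper=True. A minimal sketch, assuming a PyTorch version where torch.linalg.cholesky accepts upper=; the matrix here is an illustrative random one (not the a above), shifted by the identity and kept in float64 so round-off stays far below allclose's default tolerances.

```python
import torch

# Illustrative symmetric positive-definite matrix (not the session's a).
torch.manual_seed(0)
b = torch.randn(3, 3, dtype=torch.float64)
a = b @ b.T + torch.eye(3, dtype=torch.float64)

u = torch.linalg.cholesky(a)              # lower factor: a == u @ u.T
v = torch.linalg.cholesky(a, upper=True)  # upper factor: a == v.T @ v

inv_lower = torch.cholesky_inverse(u)              # upper=False is the default
inv_upper = torch.cholesky_inverse(v, upper=True)  # declare that v is upper triangular

print(torch.allclose(inv_lower, torch.linalg.inv(a)))
print(torch.allclose(inv_upper, inv_lower))
```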

Verify that the result really is the inverse:

In [81]: a @ a.inverse()
Out[81]: 
tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00],
        [8.9407e-08, 1.0000e+00, 4.7684e-07],
        [0.0000e+00, 5.9605e-08, 1.0000e+00]])

In [82]: a @ a_inverse
Out[82]: 
tensor([[ 1.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 5.9605e-08,  1.0000e+00,  4.7684e-07],
        [-1.4901e-08,  1.1921e-07,  1.0000e+00]])

In [83]: (a @ a_inverse).int()
Out[83]: 
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]], dtype=torch.int32)

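Here @ is the matrix multiplication operator. Both products are in fact the identity matrix; because of floating-point precision, some entries are not exactly 0 but very small numbers close to 0, and converting to int makes the result easier to read. Keep in mind that .int() truncates toward zero, so a diagonal entry that happened to come out slightly below 1 (for example 0.99999994) would become 0; a comparison with an explicit tolerance is the more robust check.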
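A tolerance-based check avoids the truncation pitfall. The sketch below re-enters a from its printed (rounded) values, so its numbers differ slightly from the session; atol=1e-4 is an arbitrary choice for float32, not a recommended constant.

```python
import torch

# The matrix a from the session, re-entered with its printed (rounded) values.
a = torch.tensor([[1.4079, -0.1658, -0.2116],
                  [-0.1658, 3.1347, 0.9066],
                  [-0.2116, 0.9066, 0.4598]])
a_inverse = torch.cholesky_inverse(torch.linalg.cholesky(a))

# Compare against the identity with an explicit tolerance instead of casting to int.
print(torch.allclose(a @ a_inverse, torch.eye(3), atol=1e-4))

# Why .int() is fragile: casting float to int truncates toward zero.
almost_one = torch.tensor([1.0 - 2**-24])  # largest float32 below 1.0
print(almost_one.int())                    # the entry becomes 0, not 1
```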

posted @ 2023-07-07 17:05  王冰冰