矩阵乘法求导

矩阵乘法求导

pyotrch中只能是标量对矩阵求导,所以矩阵乘法结束后加个sum

\[L = sum(\bm{WX}) \]

其中,\(\bm{W}\)\(\bm{X}\)都是矩阵,那么

\[\frac{\partial L}{\partial\bm{W}}_{\cdot i}=\sum\bm{X}_{i\cdot} \]

梯度和W的形状相同,梯度中每列都是相同的,只要是第i列,梯度值就是\(\bm{X}\)的第i行的和。
用公式不太好表示,我们用pytorch代码来描述一下:

>>> w = torch.arange(0,50,dtype=torch.float32, requires_grad=True)
>>> nw  = w.view(10, 5)
>>> x = torch.arange(50,100, dtype=torch.float32, requires_grad=True)
>>> nx = x.view(5,10)
>>> loss = torch.matmul(nw, nx)
>>> sum_loss = loss.sum()
>>> sum_loss.backward()
>>> print(w.grad.view_as(nw),'\n', nx.detach().sum(dim=1))
tensor([[545., 645., 745., 845., 945.],
        [545., 645., 745., 845., 945.],
        [545., 645., 745., 845., 945.],
        [545., 645., 745., 845., 945.],
        [545., 645., 745., 845., 945.],
        [545., 645., 745., 845., 945.],
        [545., 645., 745., 845., 945.],
        [545., 645., 745., 845., 945.],
        [545., 645., 745., 845., 945.],
        [545., 645., 745., 845., 945.]]) 
 tensor([545., 645., 745., 845., 945.])

这个推导过程也不是很复杂,可以自己举个例子试试。

posted @ 2023-06-14 20:26  王冰冰  阅读(145)  评论(0编辑  收藏  举报