pytorch中多个loss回传的参数影响示例
写了一段代码如下:
import torch import torch.nn as nn import torch.nn.functional as F class Test(nn.Module): def __init__(self): super(Test, self).__init__() self.fc1 = nn.Linear(5, 4) self.fc2 = nn.Linear(4, 3) self.fc3 = nn.Linear(4, 3) def forward(self, x): mid = self.fc1(x) out1 = self.fc2(mid) out2 = self.fc3(mid) return out1, out2 x = torch.randn((3, 5)) y = torch.torch.randint(3, (3,), dtype=torch.int64) model = Test() model.train() optim = torch.optim.RMSprop(model.parameters(), lr=0.001) print(model.fc2.weight) print(model.fc3.weight) for i in range(5): out1, out2 = model(x) loss1 = F.cross_entropy(out1, y) loss2 = F.cross_entropy(out2, y) loss = loss1 + loss2 optim.zero_grad() loss.backward() optim.step() print("-------------after-----------") print(model.fc2.weight) print(model.fc3.weight)
在loss.backward()处分别更换为loss1.backward()和loss2.backward(),观察fc2和fc3层的参数变化。
得出的结论为:loss2只影响fc3的参数,loss1只影响fc2的参数。
(粗略分析,抛砖引玉)