pytorch基础教程2
1. 四部曲
1)forward; 2) 计算误差 ;3)backward; 4) 更新
eg:
1)outputs = net(inputs)
2)loss = criterion(outputs, labels)
3)loss.backward()
4)optimizer.step()
其中,每步关键
1)定义网络
2)定义loss: criterion = nn.CrossEntropyLoss()
3)自动求导
4) 定义优化方法: optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
由此,麻烦的是1),2)
2.