19-神经网络-模型的加载和保存

参考:model.eval() 和 torch.no_grad() 的区别 — https://blog.csdn.net/qq_41813454/article/details/135129279

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    """Multilayer perceptron with a single ReLU-activated hidden layer.

    The layer sizes are configurable; the defaults reproduce the fixed
    20 -> 256 -> 10 architecture, so ``MLP()`` builds the same network and
    produces the same ``state_dict`` keys and shapes as before.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Number of units in the hidden layer.
        out_features: Size of the last dimension of the output.
    """

    def __init__(
        self,
        in_features: int = 20,
        hidden_features: int = 256,
        out_features: int = 10,
    ) -> None:
        super().__init__()
        # The attribute names below become the state_dict key prefixes
        # ("hidden.*", "output.*"); renaming them would break saved files.
        self.hidden = nn.Linear(in_features, hidden_features)
        self.output = nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply hidden layer, ReLU, then the output layer to ``x``."""
        return self.output(F.relu(self.hidden(x)))

# Build a model and run one forward pass; `y` is kept as the reference
# output that the reloaded model is compared against below.
net = MLP()
x = torch.randn(1, 20)
y = net(x)
print(y)

# Save only the parameters (the state_dict), not the module object itself.
torch.save(net.state_dict(), 'mlp.params')

# Restore the parameters into a freshly constructed model of the same class.
clone = MLP()
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs, and avoids executing arbitrary pickled
# code if the file was ever replaced or tampered with.
clone.load_state_dict(torch.load('mlp.params', weights_only=True))
clone.eval()  # switch the reloaded module to evaluation mode

# Same input and same parameters: the element-wise comparison shows whether
# the reloaded model reproduces the original output.
y_clone = clone(x)
print(y_clone == y)
posted @ 2024-08-25 12:08  不是孩子了  阅读(5)  评论(0)  编辑  收藏  举报