PyTorch-Load Predict
- 载入训练好的模型并且进行预测
代码:
import os

# Work around the "duplicate OpenMP runtime" abort seen on some setups.
# Set it before the numeric libraries are imported so the flag is already in
# place whenever the OpenMP runtime gets initialized.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import numpy as np
import torchvision  # torch's vision package
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import PIL.Image as Image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class CNNModel(nn.Module):
    """Small two-convolution CNN that maps 1x28x28 images to 10 class scores.

    Shape flow for a single 1x28x28 input:
    conv1 -> 20x24x24, pool -> 20x12x12, conv2 -> 12x8x8, pool -> 12x4x4,
    flatten -> 192, fc1 -> 100, fc2 -> 10.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 20 feature maps, 5x5 kernel.
        self.conv1 = nn.Conv2d(1, 20, 5)
        # 20 input channels -> 12 feature maps, 5x5 kernel.
        self.conv2 = nn.Conv2d(20, 12, 5)
        # Fully connected head; 12*4*4 is the flattened size after the
        # second pooling step.
        self.fc1 = nn.Linear(12 * 4 * 4, 100, bias=True)
        self.fc2 = nn.Linear(100, 10, bias=True)

    def forward(self, x):
        """Return the raw (un-normalized) class scores for the batch ``x``."""
        x = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        # Second pooling halves 12x8x8 down to 12x4x4.
        x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=2, stride=2)
        # Flatten every sample to a 192-element vector.
        x = x.reshape(-1, 12 * 4 * 4)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# --- Load the saved checkpoints -------------------------------------------
# The .pkl file holds the whole pickled model object, which is why CNNModel
# has to be defined above before it can be un-pickled.
# NOTE(review): torch.load un-pickles arbitrary Python objects - only load
# files you trust. Recent PyTorch releases default to weights_only=True and
# reject fully pickled modules; pass weights_only=False there if this call
# fails - confirm against the installed version.
# map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts;
# inference below runs on CPU tensors anyway.
cnnmodel = torch.load('D:/Project_Encyclopedia/cnnmodel.pkl', map_location='cpu')
print(cnnmodel)

# The .pt file is used as a state dict (parameter name -> tensor); read it
# once and reuse it instead of loading the file twice.
state_dict = torch.load('D:/Project_Encyclopedia/cnnmodel.pt', map_location='cpu')
print(state_dict)

print(type(cnnmodel))
print(type(state_dict))

# Rebuild the architecture and copy the saved parameters into it.
cnnmodel1 = CNNModel()
cnnmodel1.load_state_dict(state_dict)
cnnmodel1.eval()  # inference mode for any train/eval-dependent layers
print(type(cnnmodel1))

# --- Run a prediction on one batch of the MNIST test split -----------------
# Raw string: '\P' is not a valid escape sequence and triggers a warning.
root = r'D:\Project_Encyclopedia'
mnist = torchvision.datasets.MNIST(
    root,
    train=False,  # test split
    transform=ToTensor(),
    target_transform=None,
    download=False,
)
bs = 8
mnist_loader = DataLoader(dataset=mnist, batch_size=bs, shuffle=True)
print("test set size:", len(mnist))

batch = next(iter(mnist_loader))
image, labels = batch

# No gradients are needed for prediction.
with torch.no_grad():
    out = cnnmodel1(image)
print("output shape:", out.shape)
predictions = out.argmax(dim=1)  # index of the highest score per sample

# Tile the batch into a single image grid (one row of `bs` images).
grid = torchvision.utils.make_grid(image, nrow=bs)
plt.figure(figsize=(15, 15))
# make_grid returns channels-first (C, H, W); imshow expects (H, W, C).
plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
print("labels:", labels)
print('predicts:', predictions)
# Needed for the figure to appear when run as a plain script.
plt.show()
转载请注明出处,欢迎讨论和交流!