pytorch中vgg的预训练模型分类一张自己提供的图片
import torch import numpy import torch.nn as nn import torch.nn.functional as F from PIL import Image from torchvision import transforms import torchvision.models as models vgg = models.vgg16() pre=torch.load('./vgg16-397923af.pth') vgg.load_state_dict(pre) r""" vgg的pretrained模型是在imagenet上预训练的,提供的是一个1000分类的输出,每个类别标签见:https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/c2c91c8e767d04621020c30ed31192724b863041/imagenet1000_clsid_to_human.txt 完美的图片大小是224*224 transform就是三步走 unsqueeze后要加下划线才是原地操作,总是忘记 有了pth文件,要先torch.load一下,后load_state_dict """ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],#这是imagenet std=[0.229, 0.224, 0.225]) tran=transforms.Compose([ transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) im='./1.jpeg' im=Image.open(im) im=tran(im) im.unsqueeze_(dim=0) print(im.shape) # input() out=vgg(im) outnp=out.data[0] ind=int(numpy.argmax(outnp)) print(ind) from cls import d print(d[ind]) print(out.shape) # im.show()