import torch
import torchvision.models
from PIL import ImageTk, Image
from torchvision import transforms
from torch.autograd import Variable
import torch.nn as nn
import tkinter.filedialog
import tkinter as tk
import windnd
# The current top-level Tk window. go() (re)creates it, do() places widgets on it,
# and next() destroys it before building a fresh one.
root: "tk.Tk | None" = None
def selectPath():
    """Ask the user to pick an image file and run the diagnosis on it.

    If the dialog is cancelled, ``askopenfilename`` returns an empty value.
    In that case nothing happens; previously the empty path was handed to
    ``do()`` and ``Image.open`` raised.
    """
    path_ = tk.filedialog.askopenfilename()
    if not path_:
        # Dialog cancelled: there is nothing to classify.
        return
    do(path_)
def next():
    """Back-button handler: close the current window and rebuild the start screen.

    NOTE(review): this name shadows the builtin ``next`` at module level. It is
    kept because the back button refers to it by this name.
    """
    current_window = root
    current_window.destroy()
    go()
# Display labels for the classifier outputs, indexed by the predicted class index.
_CLASSES = ["新冠肺炎", "正常肺部", "病毒性肺炎"]

# Inference preprocessing: resize to 224x224, then convert to a float tensor.
# NOTE(review): no mean/std normalisation is applied — confirm this matches training.
_PREPROCESS = transforms.Compose([transforms.Resize([224, 224]), transforms.ToTensor()])

# Loaded models keyed by torch.device, so "8.pth" is read once per device
# instead of on every image selection.
_MODEL_CACHE = {}


def _select_device():
    """Return cuda:0 when a GPU is available; otherwise fall back to the CPU."""
    return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def _load_model(device):
    """Return the classifier, in eval mode, on *device*; load "8.pth" on first use."""
    model = _MODEL_CACHE.get(device)
    if model is None:
        model = my_resnet50().to(device)
        # map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
        state_dict = torch.load("8.pth", map_location=device)
        # strict=False is kept from the original: missing/unexpected keys are ignored.
        # NOTE(review): a mismatched checkpoint would then silently predict with
        # untrained weights — consider strict=True once the key layout is confirmed.
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        _MODEL_CACHE[device] = model
    return model


def _predict(model, img, device):
    """Classify an RGB PIL image; return (class_index, probability) of the top class."""
    batch = _PREPROCESS(img).unsqueeze(0).float().to(device)
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)
    prob, index = torch.max(probs, dim=1)
    return int(index.item()), float(prob.item())


def do(path):
    """Classify the image at *path* and show it, with the prediction, in the main window.

    Parameters
    ----------
    path : str
        Path of the image file to diagnose. An empty path is ignored.

    The function returns to the event loop that go() is already running. It no
    longer starts a nested ``root.mainloop()`` inside the button callback.
    """
    if not path:
        return
    device = _select_device()
    model = _load_model(device)

    # The context manager closes the underlying file. convert() returns an
    # independent in-memory copy that is used for both the preview and inference.
    with Image.open(path) as src:
        img = src.convert('RGB')

    photo = ImageTk.PhotoImage(img.resize((300, 300)))
    preview = tk.Label(root, image=photo)
    # Tkinter keeps no Python reference to the image. Without this attribute the
    # PhotoImage is garbage-collected when do() returns and the label goes blank.
    preview.image = photo
    preview.place(relx=0.25, rely=0.1)

    index, prob = _predict(model, img, device)
    result_label = tk.Label(root,
                            text="预测结果为:" + _CLASSES[index] + "\n" + "概率为:" + f"{prob * 100:.2f}" + "%",
                            font=("黑体", 20))
    result_label.place(relx=0.2, rely=0.7)
    back_button = tk.Button(root, text='返回', command=next)
    back_button.place(relx=0.4, rely=0.8, relwidth=0.2, relheight=0.1)
class my_resnet50(nn.Module):
    """ResNet-50 backbone followed by two linear layers (1000 -> 512 -> 3).

    The backbone is built with ``pretrained=False``, so no ImageNet weights are
    downloaded here. ``forward`` returns raw scores; no softmax is applied.

    NOTE(review): fc2 and fc3 run back-to-back with no activation between them.
    They presumably must stay this way (and keep their attribute names) to match
    the saved checkpoint's layout.
    """

    def __init__(self):
        super().__init__()
        # Registration order (backbone, fc2, fc3) matches the original definition.
        self.backbone = torchvision.models.resnet50(pretrained=False)
        self.fc2 = nn.Linear(1000, 512)
        self.fc3 = nn.Linear(512, 3)

    def forward(self, x):
        # x presumably is an image batch (N, 3, H, W) — TODO confirm with callers.
        return self.fc3(self.fc2(self.backbone(x)))
def go():
    """Create the start window (title, prompt, image-picker button) and run its event loop.

    The new window is stored in the module-level ``root`` so the other
    callbacks can place widgets on it or destroy it.
    """
    global root
    window = tk.Tk()
    root = window
    window.title('AI肺炎类型诊断')
    window.geometry('600x600')

    prompt = tk.Label(window, text='请选择待诊断的图片', font=("黑体", 20), pady=20)
    prompt.pack()

    pick_button = tk.Button(window, text='选择图片', command=selectPath)
    pick_button.place(relx=0.4, rely=0.1, relwidth=0.2, relheight=0.1)

    window.mainloop()
# Script entry point: open the start window when run directly (not when imported).
if __name__ == '__main__':
    go()