# code

import torch
import torchvision.models

from PIL import ImageTk, Image

from torchvision import transforms
from torch.autograd import Variable
import torch.nn as nn
import tkinter.filedialog
import tkinter as tk
import windnd

# The application's Tk main window. It is None until go() creates the window;
# do() and next() read it to attach widgets and to destroy the window.
root = None


def selectPath():
    """Ask the user for an image file and run the diagnosis on it.

    Opens the native "open file" dialog. When the user dismisses the dialog
    without choosing a file, nothing happens; otherwise the chosen path is
    handed to ``do``.
    """
    path_ = tk.filedialog.askopenfilename()
    # A cancelled dialog yields an empty value; passing that on would make
    # Image.open() fail inside do().
    if not path_:
        return
    do(path_)

def next():
    """Tear down the current window and show a fresh start screen.

    Used as the command of the button that ``do`` creates.
    """
    # NOTE(review): this name shadows the builtin next(); it is kept because
    # do() refers to it when wiring the button.
    current_window = root
    current_window.destroy()
    go()
def do(path):
    """Classify the image at ``path`` and display it with the result in ``root``.

    Builds a ``my_resnet50`` model, loads its weights from ``8.pth``, runs one
    forward pass on the image and places three widgets on the global ``root``
    window: a 300x300 preview of the image, a label with the predicted class
    and its softmax probability, and a button bound to ``next``.

    The function returns once the widgets are placed. It does not start an
    event loop of its own and relies on the one already running.

    Args:
        path: File-system path of the image to classify.
    """
    # Fall back to the CPU so the tool also runs on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = my_resnet50().to(device)
    # map_location lets a checkpoint that was saved on a GPU load on a
    # CPU-only host. strict=False (unchanged) tolerates key mismatches
    # between the checkpoint and the model.
    state = torch.load("8.pth", map_location=device)
    model.load_state_dict(state, strict=False)
    model.eval()

    # The context manager closes the underlying file; resize() and convert()
    # both return new image objects that stay usable afterwards.
    with Image.open(path) as img:
        preview = img.resize((300, 300))
        rgb = img.convert('RGB')

    photo = ImageTk.PhotoImage(preview)
    ph = tk.Label(root, image=photo)
    # Tkinter does not hold a Python reference to the PhotoImage. Keep one on
    # the widget so the picture is not garbage-collected when this function
    # returns.
    ph.image = photo
    ph.place(relx=0.25, rely=0.1)

    val = transforms.Compose([transforms.Resize([224, 224]), transforms.ToTensor()])
    # Add the batch dimension the model expects and move the input to the
    # same device as the model.
    img_tensor = torch.unsqueeze(val(rgb), dim=0).float().to(device)

    # Only the forward pass needs to run without gradient tracking.
    with torch.no_grad():
        output = model(img_tensor)
        o = torch.softmax(output, dim=1)
    pred_val, pred_index = torch.max(o, 1)
    pred_val = pred_val.cpu().numpy()
    pred_index = pred_index.cpu().numpy()
    classes = ["新冠肺炎", "正常肺部", "病毒性肺炎"]

    test = tk.Label(root,
                    text="预测结果为:" + classes[pred_index[0]] + "\n" + "概率为:" + str(pred_val[0] * 100) + "%",
                    font=("黑体", 20))
    test.place(relx=0.2, rely=0.7)
    tb = tk.Button(root, text='返回', command=next)
    tb.place(relx=0.4, rely=0.8, relwidth=0.2, relheight=0.1)


class my_resnet50(nn.Module):
    """ResNet-50 backbone followed by two linear layers (1000 -> 512 -> 3).

    The backbone is created without pretrained weights. Its 1000-dimensional
    output is passed through ``fc2`` and then ``fc3`` with no activation in
    between, giving three output values per input sample.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are part of the state-dict keys, so they must stay
        # exactly as they are for saved checkpoints to match.
        self.backbone = torchvision.models.resnet50(pretrained=False)
        self.fc2 = nn.Linear(1000, 512)
        self.fc3 = nn.Linear(512, 3)

    def forward(self, x):
        """Return the three raw output scores for the input batch ``x``."""
        features = self.backbone(x)
        return self.fc3(self.fc2(features))


def go():
    """Create the start window and run the Tk event loop.

    A new 600x600 ``tk.Tk`` window is stored in the module-level ``root``.
    It gets a prompt label and a button bound to ``selectPath``; the function
    then blocks in ``mainloop``.
    """
    global root
    root = tk.Tk()
    root.geometry('600x600')
    root.title('AI肺炎类型诊断')

    # The prompt is packed at the top; the button is positioned relative to
    # the window size.
    tk.Label(root, text='请选择待诊断的图片', font=("黑体", 20), pady=20).pack()
    choose_button = tk.Button(root, text='选择图片', command=selectPath)
    choose_button.place(relx=0.4, rely=0.1, relwidth=0.2, relheight=0.1)

    root.mainloop()


# Launch the GUI only when the file is executed as a script, not on import.
if __name__ == '__main__':
    go()
# posted @ 2024-06-25 02:24  Jefferyzzzz  阅读(65)  评论(0)  编辑  收藏  举报