# code

import os.path
import random
import numpy as np
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
from torch import optim

# Class image directories under <cwd>/input_data. Index 0 holds a placeholder
# string so that positions 1..3 line up with the labels that process_data()
# writes for the COVID, NORMAL and Viral_Pneumonia folders.
datapath = ["homo"] + [
    os.path.join(os.getcwd(), "input_data", class_folder)
    for class_folder in ("COVID", "NORMAL", "Viral_Pneumonia")
]


class my_resnet50(nn.Module):
    """ResNet-50 backbone followed by two linear layers ending in 3 outputs.

    The sub-modules are registered as ``backbone``, ``fc2`` and ``fc3`` (in
    that order), which is what determines the keys of ``state_dict()``.
    """

    def __init__(self):
        super().__init__()
        # Backbone is created without pretrained weights.
        # NOTE(review): recent torchvision releases prefer ``weights=None``
        # over ``pretrained=False`` - confirm against the installed version.
        self.backbone = torchvision.models.resnet50(pretrained=False)
        # fc2 expects 1000 input features - presumably the width of the
        # backbone's default classifier output; verify against torchvision.
        self.fc2 = nn.Linear(in_features=1000, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=3)

    def forward(self, x):
        """Run ``x`` through backbone -> fc2 -> fc3 and return the result.

        No activation is applied between ``fc2`` and ``fc3``.
        """
        backbone_out = self.backbone(x)
        hidden = self.fc2(backbone_out)
        return self.fc3(hidden)


def process_data(class_dirs=None, list_dir=None, train_per_class=1000):
    """Write ``train.txt`` and ``test.txt`` listing image paths with labels.

    Every class directory is listed in sorted order. The first
    ``train_per_class`` entries of each class go to ``train.txt`` and the
    remaining ones to ``test.txt``. Each output line has the form
    ``<path> <label>``, where the label is the 1-based position of the class
    directory in ``class_dirs``.

    Args:
        class_dirs: Sequence of class directories. Defaults to
            ``datapath[1:4]`` (labels 1, 2 and 3).
        list_dir: Directory that receives ``train.txt`` and ``test.txt``.
            Defaults to ``<cwd>/input_data``.
        train_per_class: Number of files per class assigned to the training
            list before the rest spill over into the test list.

    Notes:
        * Both list files are rewritten from scratch on every call, so running
          the function again does not duplicate entries.
        * Entries whose extension is ``.txt`` are skipped.
        * Paths are built with ``os.path.join`` so the separator matches the
          running platform.
    """
    if class_dirs is None:
        class_dirs = datapath[1:4]
    if list_dir is None:
        list_dir = os.path.join(os.getcwd(), "input_data")

    train_path = os.path.join(list_dir, "train.txt")
    test_path = os.path.join(list_dir, "test.txt")

    # Open each list once for the whole run; the context manager guarantees
    # both handles are flushed and closed even if a directory listing fails.
    with open(train_path, "w") as train, open(test_path, "w") as test:
        for label, class_dir in enumerate(class_dirs, start=1):
            kept = 0
            for name in sorted(os.listdir(class_dir)):
                # splitext()[1] is the extension; split()[0] would be the
                # (empty) directory part and could never equal '.txt'.
                if os.path.splitext(name)[1] == '.txt':
                    continue
                kept += 1
                sink = train if kept <= train_per_class else test
                sink.write(os.path.join(class_dir, name) + ' ' + str(label) + '\n')


def default_loader(path):
    """Open the image file at ``path`` and return it converted to RGB mode.

    Args:
        path: Location of the image file, passed to ``Image.open``.

    Returns:
        The image returned by ``convert('RGB')``.
    """
    # convert() hands back a separate, fully loaded image, so the handle that
    # Image.open() holds on the source file can be released right away instead
    # of lingering until garbage collection.
    with Image.open(path) as img:
        return img.convert('RGB')


class MyDataset(Dataset):
    """Image dataset backed by a text file of ``<image path> <label>`` lines.

    Args:
        txt: Path of the list file. Each non-blank line holds an image path
            followed by whitespace and an integer label. The path itself may
            contain spaces; only the last whitespace-separated token is
            treated as the label.
        transform: Optional callable applied to every loaded image.
        target_transform: Optional callable applied to every label.
        loader: Callable that receives an image path and returns the image.

    Raises:
        ValueError: If a non-blank line does not contain both a path and a
            label, or if the label is not an integer.
    """

    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []
        with open(txt, 'r') as fh:
            for line_no, raw in enumerate(fh, start=1):
                line = raw.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                # Split once from the right so spaces inside the path survive
                # regardless of how many there are or which label follows.
                parts = line.rsplit(None, 1)
                if len(parts) != 2:
                    raise ValueError(
                        "%s:%d: expected '<path> <label>', got %r" % (txt, line_no, line)
                    )
                path, label = parts
                imgs.append((path, int(label)))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` after the optional transforms."""
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        # Previously accepted by the constructor but never applied.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Return the number of samples read from the list file."""
        return len(self.imgs)


# Image preprocessing pipeline: resize to 224x224, apply a random horizontal
# flip, then convert the image to a tensor.
# NOTE(review): this statement rebinds the name ``transforms`` from the
# imported torchvision module to the composed pipeline, so the module is no
# longer reachable under that name below this point.
transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
if __name__ == "__main__":
    # Hyper-parameters of the run. These are the values the script has always
    # trained with; the previous BATCH_SIZE/EPOCHS/LR constants were never
    # read and contradicted them.
    BATCH_SIZE = 6
    EPOCHS = 8
    LR = 1e-4

    # Write the train/test list files, then build the datasets from them.
    process_data()
    list_dir = os.path.join(os.getcwd(), "input_data")
    train_data = MyDataset(txt=os.path.join(list_dir, 'train.txt'), transform=transforms)
    test_data = MyDataset(txt=os.path.join(list_dir, 'test.txt'), transform=transforms)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    # The test loader is constructed but not consumed anywhere in this script.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # Train on the GPU when one is available; otherwise fall back to the CPU
    # instead of failing at start-up.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = my_resnet50().to(device)
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    train_loss = []  # per-batch loss values across all epochs
    for epoch in range(EPOCHS):
        for batch_idx, (x, y) in enumerate(train_loader):
            x = x.to(device)
            # Shift the labels down by one before handing them to the loss
            # (the list files written by process_data() use labels 1..3).
            y = (y - 1).to(device)

            optimizer.zero_grad()
            pred = net(x)
            loss = loss_func(pred, y)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            train_loss.append(batch_loss)

            print(["epoch:%d , batch:%d , loss:%.3f" % (epoch, batch_idx, batch_loss)])
        # Save a checkpoint named <epoch number>.pth in the working directory
        # after every epoch.
        torch.save(net.state_dict(), os.path.join(os.getcwd(), str(epoch + 1) + '.pth'))
# posted @ 2024-03-22 06:52  Jefferyzzzz  Views (63)  Comments (0)  Edit  Favorite  Report