import os

import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
datapath = ["homo", os.path.join(os.getcwd(), "input_data", "COVID"), os.path.join(os.getcwd(), "input_data", "NORMAL"),
os.path.join(os.getcwd(), "input_data", "Viral_Pneumonia")]
class my_resnet50(nn.Module):
    """ResNet-50 backbone followed by a two-layer classification head."""

    def __init__(self):
        super(my_resnet50, self).__init__()
        # Randomly initialized backbone (weights=None is the current
        # torchvision spelling of pretrained=False).
        self.backbone = torchvision.models.resnet50(weights=None)
        # Map the backbone's 1000-way output down to the 3 classes.
        self.fc2 = nn.Linear(1000, 512)
        self.fc3 = nn.Linear(512, 3)

    def forward(self, x):
        x = self.backbone(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
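
# Illustrative sanity check (a sketch, not part of the training flow):
# a dummy batch of two 224x224 RGB images should yield logits of shape (2, 3).
#     logits = my_resnet50()(torch.randn(2, 3, 224, 224))
#     assert logits.shape == (2, 3)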
def process_data():
    """Write "path label" lines: the first 1000 images of each class go to
    train.txt, the rest to test.txt."""
    train_path = os.path.join(os.getcwd(), "input_data", "train.txt")
    test_path = os.path.join(os.getcwd(), "input_data", "test.txt")
    # Open once in "w" mode so rerunning the script does not duplicate entries.
    with open(train_path, "w") as train, open(test_path, "w") as test:
        for i in range(1, 4):
            dir = datapath[i]
            files = os.listdir(dir)
            files.sort()
            idx = 0
            for file in files:
                # Skip stray .txt files in the image directories.
                if os.path.splitext(file)[1] == '.txt':
                    continue
                idx += 1
                line = os.path.join(dir, file) + ' ' + str(i) + '\n'
                if idx <= 1000:
                    train.write(line)
                else:
                    test.write(line)
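
# Example lines written by process_data() (illustrative paths; file names in
# this dataset typically contain spaces, e.g. "COVID (1).png", which is why
# MyDataset below splits off only the trailing label):
#   .../input_data/COVID/COVID (1).png 1
#   .../input_data/NORMAL/NORMAL (2).png 2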
def default_loader(path):
    # The X-ray images are grayscale; convert to 3-channel RGB to match
    # the ResNet input stem.
    return Image.open(path).convert('RGB')
class MyDataset(Dataset):
    """Reads "path label" lines produced by process_data()."""

    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []
        with open(txt, 'r') as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    continue
                # File names may contain spaces (e.g. "COVID (1).png"), so
                # split off only the trailing label.
                path, label = line.rsplit(' ', 1)
                imgs.append((path, int(label)))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)
# Shared preprocessing for both splits (named data_transform so the
# torchvision "transforms" module is not shadowed).
data_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomHorizontalFlip(),  # light augmentation
    transforms.ToTensor(),              # HWC uint8 -> CHW float in [0, 1]
])
if __name__ == "__main__":
    BATCH_SIZE = 6
    EPOCHS = 8
    LR = 1e-4

    process_data()
    train_data = MyDataset(txt=os.path.join(os.getcwd(), "input_data", "train.txt"), transform=data_transform)
    test_data = MyDataset(txt=os.path.join(os.getcwd(), "input_data", "test.txt"), transform=data_transform)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = my_resnet50().to(device)
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)

    train_loss = []
    net.train()
    for epoch in range(EPOCHS):
        sum_loss = 0.0
        for batch_idx, (x, y) in enumerate(train_loader):
            x = x.to(device)
            # Labels in the txt files are 1-based (1..3); CrossEntropyLoss
            # expects 0-based class indices.
            y = (y - 1).to(device)
            pred = net(x)
            optimizer.zero_grad()
            loss = loss_func(pred, y)
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            train_loss.append(loss.item())
            print("epoch:%d , batch:%d , loss:%.3f" % (epoch, batch_idx, loss.item()))
        print("epoch %d mean loss: %.3f" % (epoch, sum_loss / len(train_loader)))
        # Checkpoint after every epoch.
        torch.save(net.state_dict(), os.path.join(os.getcwd(), str(epoch + 1) + ".pth"))
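
    # Held-out evaluation, a minimal sketch: the script builds test_loader but
    # never consumes it, and the metric here (plain accuracy) is an assumption.
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = (y - 1).to(device)  # same 1-based -> 0-based label shift as training
            pred = net(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    print("test accuracy: %.3f" % (correct / total))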