PyTorch 实践：猫狗分类（Dog vs. Cat）
猫狗分类，练手级代码。与手写数字识别相比，主要修改的地方是输出全连接层：将输出单元数由 10（十个数字）改成 2（猫、狗二分类）。另一处改动是数据集处理：torchvision 没有内置猫狗数据集，因此这里自己读取和处理图片（对于下面这种"每个类别一个子目录"的结构，也可以直接使用 torchvision.datasets.ImageFolder）。
图片用 OpenCV 读取并缩放到 32×32，再转换成张量并归一化到 [-1, 1]。
数据集目录结构：
data
├── train
│   ├── Cat
│   └── Dog
└── test
    ├── Cat
    └── Dog
get_data.py
import os
import time

import cv2
import torch
from torchvision import transforms

# ToTensor turns an HWC uint8 array into a CHW float tensor in [0, 1];
# Normalize with mean = std = 0.5 then maps every channel to [-1, 1].
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
    ]
)

DATA_PATH = './data/'
PIC_SIZE = 32
# Seconds between two "Take N seconds" progress messages.
PROGRESS_INTERVAL = 20


def _load_folder(folder, label, samples, start_time, last_report):
    """Append ``[normalized_tensor, label]`` for every readable image in *folder*.

    Regular files that OpenCV cannot decode or resize are deleted (the original
    script's way of cleaning broken downloads). Sub-directories are skipped.
    Only decode/resize failures trigger deletion, so e.g. a KeyboardInterrupt
    can no longer remove a valid image.

    Args:
        folder: directory containing the images of one class.
        label: integer class label attached to every image (0 = cat, 1 = dog).
        samples: list the ``[tensor, label]`` pairs are appended to.
        start_time: ``time.time()`` at the start of loading, for progress output.
        last_report: ``time.time()`` of the last progress message.

    Returns:
        The updated ``last_report`` timestamp, so one progress clock spans
        all folders.
    """
    for name in os.listdir(folder):
        path = os.path.join(folder, name)
        if not os.path.isfile(path):
            continue
        # imread returns None (instead of raising) for unreadable files.
        # The default flag loads 3 channels in BGR order.
        image = cv2.imread(path)
        resized = None
        if image is not None:
            try:
                resized = cv2.resize(image, (PIC_SIZE, PIC_SIZE))
            except cv2.error:
                resized = None
        if resized is None:
            os.remove(path)
        else:
            samples.append([trans(resized), label])

        now = time.time()
        if now - last_report > PROGRESS_INTERVAL:
            last_report = now
            print('Take %d seconds' % (now - start_time))
    return last_report


def get_files():
    """Load, resize and normalize the cat/dog images under DATA_PATH.

    Expects ``DATA_PATH/{train,test}/{Cat,Dog}/`` directories.

    Returns:
        ``(train_data, test_data)``: two lists of ``[image_tensor, label]``
        pairs. Label 0 = cat, 1 = dog. image_tensor is the ``trans`` output
        for a PIC_SIZE x PIC_SIZE 3-channel (BGR) image.
    """
    train_data = []
    test_data = []
    print('now,loading data.due to The amount of data is huge,you have to wait minutes')
    start_time = last_report = time.time()
    for split, samples in (('train', train_data), ('test', test_data)):
        for class_dir, label in (('Cat', 0), ('Dog', 1)):
            folder = os.path.join(DATA_PATH, split, class_dir)
            last_report = _load_folder(folder, label, samples, start_time, last_report)
    print('have loaded the data:\nThere are %d train_data\nThere are %d test_data'
          % (len(train_data), len(test_data)))
    print('-----------------------------------------------------------------------------')
    return train_data, test_data


if __name__ == '__main__':
    torch.save(get_files(), "data.pyd")
运行 get_data.py，将预处理后的数据集保存到 data.pyd。
然后运行 dogVScat.py 进行训练和测试。
dogVScat.py
import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import numpy as np LR = 0.01 MOM = 0.5 EPOCHES=100 BATCHSIZE=50 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(in_channels=3,out_channels=10,kernel_size=3) self.conv2 = nn.Conv2d(10,20,3) self.conv3 = nn.Conv2d(20,10,3) self.mp = nn.MaxPool2d(2) self.fc = nn.Linear(40,2) def forward(self,x): in_size = x.size(0) x = F.relu(self.mp(self.conv1(x))) x = F.relu(self.mp(self.conv2(x))) x = F.relu(self.mp(self.conv3(x))) x = x.view(in_size,-1) x = self.fc(x) return F.log_softmax(x,dim=1) def train(): xbatch = [] ybatch = [] for i, (x, y) in enumerate(train_data): xbatch.append(x) ybatch.append(y) if (i+1) % BATCHSIZE == 0: xbatch = torch.stack(xbatch) #convert list of tensor into tensor ybatch = torch.Tensor(ybatch).long() out = model(xbatch) loss = F.nll_loss(out, ybatch) xbatch = [] ybatch = [] optimizer.zero_grad() loss.backward() optimizer.step() # print(str(epoch)+" epoch has Completed training") # torch.save(model,str(epoch)+".pkl") def test(epoch): test_loss = 0 correct = 0 xbatch = [] ybatch = [] for i,(x,y) in enumerate(test_data): xbatch.append(x) ybatch.append(y) if (i+1) % BATCHSIZE == 0: xbatch = torch.stack(xbatch) #convert list of tensor into tensor ybatch = torch.Tensor(ybatch).long() output = model(xbatch) pred=torch.max(output,1)[1] correct +=pred.eq(ybatch).sum(0).numpy() # test_loss += F.nll_loss(output, ybatch).data[0] xbatch = [] ybatch = [] print('correct of epoch {} is {:.2f}%'.format(epoch,correct/len(test_data)*100)) if __name__ == '__main__': model = Net() optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOM) train_data, test_data = torch.load("data.pyd") np.random.shuffle(train_data) for epoch in range(EPOCHES): train() test(epoch)
训练结果:
correct of epoch 0 is 52.33% correct of epoch 1 is 54.84% correct of epoch 2 is 55.95% correct of epoch 3 is 56.59% correct of epoch 4 is 57.57% correct of epoch 5 is 60.50% correct of epoch 6 is 62.18% correct of epoch 7 is 63.81% correct of epoch 8 is 64.46% correct of epoch 9 is 65.24% correct of epoch 10 is 65.93% correct of epoch 11 is 66.55% correct of epoch 12 is 67.47% correct of epoch 13 is 68.45% correct of epoch 14 is 69.00% correct of epoch 15 is 69.62% correct of epoch 16 is 69.99% correct of epoch 17 is 70.58% correct of epoch 18 is 71.10% correct of epoch 19 is 71.42% correct of epoch 20 is 71.87% correct of epoch 21 is 72.31% correct of epoch 22 is 72.36% correct of epoch 23 is 72.76% correct of epoch 24 is 73.01% correct of epoch 25 is 73.32% correct of epoch 26 is 73.36% correct of epoch 27 is 73.51% correct of epoch 28 is 73.17% correct of epoch 29 is 73.38% correct of epoch 30 is 73.50% correct of epoch 31 is 73.73% correct of epoch 32 is 73.93% correct of epoch 33 is 74.15% correct of epoch 34 is 74.11% correct of epoch 35 is 74.22% correct of epoch 36 is 74.26% correct of epoch 37 is 74.07% correct of epoch 38 is 74.12% correct of epoch 39 is 74.35% correct of epoch 40 is 74.38% correct of epoch 41 is 74.44% correct of epoch 42 is 74.17% correct of epoch 43 is 74.19% correct of epoch 44 is 74.30% correct of epoch 45 is 74.61% correct of epoch 46 is 74.64% correct of epoch 47 is 74.54% correct of epoch 48 is 74.58% correct of epoch 49 is 74.59% correct of epoch 50 is 74.59% correct of epoch 51 is 74.53% correct of epoch 52 is 74.45% correct of epoch 53 is 74.43% correct of epoch 54 is 74.43% correct of epoch 55 is 74.41% correct of epoch 56 is 74.42% correct of epoch 57 is 74.52% correct of epoch 58 is 74.48% correct of epoch 59 is 74.34% correct of epoch 60 is 74.21% correct of epoch 61 is 74.16% correct of epoch 62 is 74.15% correct of epoch 63 is 74.25% correct of epoch 64 is 74.11% correct of epoch 65 is 73.95% correct of epoch 66 is 73.85% 
correct of epoch 67 is 73.99% correct of epoch 68 is 74.15% correct of epoch 69 is 74.05% correct of epoch 70 is 74.05% correct of epoch 71 is 74.34% correct of epoch 72 is 74.21% correct of epoch 73 is 74.14% correct of epoch 74 is 73.98% correct of epoch 75 is 73.87% correct of epoch 76 is 73.88% correct of epoch 77 is 73.85% correct of epoch 78 is 73.84% correct of epoch 79 is 73.84% correct of epoch 80 is 73.65% correct of epoch 81 is 73.66% correct of epoch 82 is 73.43% correct of epoch 83 is 73.36% correct of epoch 84 is 73.30% correct of epoch 85 is 73.12% correct of epoch 86 is 73.20% correct of epoch 87 is 73.22% correct of epoch 88 is 73.13% correct of epoch 89 is 73.16% correct of epoch 90 is 73.17% correct of epoch 91 is 72.99% correct of epoch 92 is 73.09% correct of epoch 93 is 73.02% correct of epoch 94 is 72.80% correct of epoch 95 is 72.98% correct of epoch 96 is 72.73% correct of epoch 97 is 72.80% correct of epoch 98 is 72.76% correct of epoch 99 is 72.68%
最高准确率为 74.64%（第 46 个 epoch）；此后测试准确率逐渐下降（第 99 个 epoch 为 72.68%），说明继续训练后模型可能出现了过拟合。