pytorch-day09(自定义数据集 & 迁移学习)

1、自定义数据集

 1 """
 2 自定义数据集的基础操作
 3 """
 4 import torch
 5 from torch.utils.data import Dataset
 6 
 7 class Pokemon(Dataset):  # 继承Datastet类,如自定义模型继承Module类一样
 8     def __init__(self):
 9         super(Pokemon, self).__init__()
10         pass
11 
12     def __len__(self):  
13         pass
14 
15     def __getitem__(self, idx): 
16         pass  

  例如:

 1 import torch
 2 from torch.utils.data import Dataset
 3 
 4 class NumbersDataset(Dataset):  # 继承Datastet类,如自定义模型继承Module类一样
 5     def __init__(self, training=True):
 6         if training:
 7             self.samples = list(range(1, 1001))  # 训练数据集
 8         else:
 9             self.samples = list(range(1001, 1501))
10 
11     def __len__(self):  # 返回元素的个数(数据集的个数)。
12         return len(self.samples)
13 
14     def __getitem__(self, idx):  # 类的实例对象(p),可以像p[key]取值,当实例对象做p[key]运算时,会调用__getitem__()方法。
15         return self.samples[idx]  # 返回当前具体数据,idx的最大取值为len(samples)

 数据预处理

 

  pokemon数据集:

  1 import torch
  2 import os, glob
  3 import random, csv
  4 from torch.utils.data import Dataset, DataLoader  # Dataloader:实现batch加载数据
  5 from torchvision import transforms
  6 from PIL import Image
  7 import visdom
  8 import time
  9 
 10 
 11 class Pokemon(Dataset):  # 继承Datastet类,如自定义模型继承Module类一样
 12     def __init__(self, root, resize, mode):
 13         super(Pokemon, self).__init__()
 14         self.root = root
 15         self.resize = resize
 16 
 17         self.name2lable = {}
 18         for name in sorted(os.listdir(os.path.join(root))):
 19             if not os.path.isdir(os.path.join(root, name)):
 20                 continue
 21             self.name2lable[name] = len(self.name2lable.keys())
 22 
 23         print(self.name2lable)
 24         # image, label
 25         self.images, self.labels = self.load_csv('images.csv')
 26 
 27         if mode == 'train':  # 60%
 28             self.images = self.images[:int(0.6 * len(self.images))]
 29             self.labels = self.labels[:int(0.6 * len(self.labels))]
 30         elif mode == 'val':  # 20%  60%-80%
 31             self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
 32             self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
 33         else:  # 20% test
 34             self.images = self.images[int(0.8 * len(self.images)):]
 35             self.labels = self.labels[int(0.8 * len(self.labels)):]
 36 
 37     def load_csv(self, filename):
 38         if not os.path.join(self.root, filename):
 39             images = []
 40             for name in self.name2lable.keys():
 41                 images += glob.glob(os.path.join(self.root, name, '*.png'))
 42                 images += glob.glob(os.path.join(self.root, name, '*.jpg'))
 43                 images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
 44 
 45             print(len(images), images)  # pokemon\\bulbasaur\\00000000.png
 46             random.shuffle(images)
 47             with open(os.path.join(self.root, filename), mode='w', newline='') as f:
 48                 writer = csv.writer(f)
 49                 for img in images:
 50                     name = img.split(os.sep)[-2]
 51                     label = self.name2lable[name]
 52                     writer.writerow([img, label, ])
 53                 print("writer into csv file:", filename)
 54 
 55         # read from csv file
 56         images, labels = [], []
 57         with open(os.path.join(self.root, filename)) as f:
 58             reader = csv.reader(f)
 59             for row in reader:  # pokemon\mewtwo\00000005.png,2
 60                 img, label = row  # img为第一列, label为第二列
 61                 label = int(label)
 62 
 63                 images.append(img)
 64                 labels.append(label)
 65         assert len(images) == len(labels)
 66         return images, labels
 67 
 68     def __len__(self):  # 返回元素的个数(数据集的个数)。
 69         return len(self.images)
 70 
 71     def denormalize(self, x_hat):  # 可视化的时候需要还原图片形状
 72         mean = [0.485, 0.456, 0.406]
 73         std = [0.229, 0.224, 0.225]
 74 
 75         # x_hat = (x-mean) / std
 76         # x = x_hat*std = mean
 77         # x: [c, h ,w]
 78         # mean: [3] => [3, 1, 1]
 79         mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
 80         std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
 81         x = x_hat * std + mean
 82         return x
 83 
 84     def __getitem__(self, idx):  # 类的实例对象(p),可以像p[key]取值,当实例对象做p[key]运算时,会调用__getitem__()方法。
 85         # self.images self.labels idx:[0-len(images)]
 86         img, label = self.images[idx], self.labels[idx]  # 返回当前具体数据
 87 
 88         transf = transforms.Compose([
 89             lambda x: Image.open(x).convert('RGB'),  # string path ---> image data
 90             transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
 91             transforms.RandomRotation(15),  # 旋转15度
 92             transforms.CenterCrop(self.resize),  #
 93             transforms.ToTensor(),
 94             transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 会影响图片的可视化效果,需要做denormalize
 95                                  std=[0.229, 0.224, 0.225])
 96         ])
 97 
 98         img = transf(img)
 99         label = torch.tensor(label)
100         return img, label
101 
102 
def main():
    """Preview the training split in visdom, one shuffled batch every 10 seconds."""
    viz = visdom.Visdom()
    dataset = Pokemon('pokemon', 224, 'train')

    # Fetch the first sample to check shapes, e.g.
    # samples: torch.Size([3, 224, 224]) torch.Size([]) tensor(2)
    first_img, first_label = next(iter(dataset))
    print('samples:', first_img.shape, first_label.shape, first_label)

    # shuffle=True: batches are drawn in random order.
    batches = DataLoader(dataset, batch_size=32, shuffle=True)

    for imgs, labels in batches:
        restored = dataset.denormalize(imgs)
        viz.images(restored, nrow=8, win='batch', opts=dict(title='batch'))  # 8 images per row
        viz.text(str(labels.numpy()), win='label', opts=dict(title='batch-y'))

        time.sleep(10)
118 
119 
120 if __name__ == '__main__':
121     main()

 使用API完成:

 1 from torch.utils.data import Dataset, DataLoader  
 2 import torchvision
 3 from torchvision import transforms
 4 import visdom
 5 import time
 6 
 7 
 8 def main():
 9     viz = visdom.Visdom()
10     transf = transforms.Compose([
11         transforms.Resize((64, 64)),
12         transforms.ToTensor()
13     ])
14 
15     db = torchvision.datasets.ImageFolder(root='pokemon', transform=transf)
16     loder = DataLoader(db, batch_size=32, shuffle=True)  # Dataloader:实现batch加载数据
17     print(db.class_to_idx)  # 打印编码 {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
18 
19     for x, y in loder:
20         viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))  # nrow:每行显示8张
21         viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
22         time.sleep(10)
23 
24 
25 if __name__ == '__main__':
26     main()

 2、创建模型

 Inherit from base class;Define forward graph。

  1 import torch
  2 from torch import nn
  3 from torch.nn import functional as F
  4 
  5 
  6 class ResBlk(nn.Module):
  7     """
  8     resnet block
  9     """
 10 
 11     def __init__(self, ch_in, ch_out, stride=1):
 12         super(ResBlk, self).__init__()
 13 
 14         # we add stride support for resbok, which is distinct from tutorials.
 15         self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
 16         self.bn1 = nn.BatchNorm2d(ch_out)
 17         self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
 18         self.bn2 = nn.BatchNorm2d(ch_out)
 19 
 20         self.extra = nn.Sequential()
 21         if ch_out != ch_in:
 22             # [b, ch_in, h, w] => [b, ch_out, h, w]
 23             self.extra = nn.Sequential(
 24                 nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),  # 1x1卷积核的作用
 25                 nn.BatchNorm2d(ch_out)
 26             )
 27 
 28     def forward(self, x):
 29         """
 30         :param x: [b, ch, h, w]
 31         :return:
 32         """
 33         out = F.relu(self.bn1(self.conv1(x)))
 34         out = self.bn2(self.conv2(out))
 35         # short cut.
 36         # extra module: [b, ch_in, h, w] => [b, ch_out, h, w]
 37         # element-wise add:
 38         out = self.extra(x) + out
 39         out = F.relu(out)
 40 
 41         return out
 42 
 43 
 44 class ResNet18(nn.Module):
 45 
 46     def __init__(self, num_class):
 47         super(ResNet18, self).__init__()
 48 
 49         self.conv1 = nn.Sequential(
 50             nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
 51             nn.BatchNorm2d(16)
 52         )
 53         # followed 4 blocks
 54         # [b, 64, h, w] => [b, 128, h ,w]
 55         self.blk1 = ResBlk(16, 32, stride=2)
 56         # [b, 128, h, w] => [b, 256, h, w]
 57         self.blk2 = ResBlk(32, 64, stride=2)
 58         # # [b, 256, h, w] => [b, 512, h, w]
 59         self.blk3 = ResBlk(64, 128, stride=2)
 60         # # [b, 512, h, w] => [b, 1024, h, w]
 61         self.blk4 = ResBlk(128, 256, stride=2)
 62 
 63         self.outlayer = nn.Linear(256 * 2 * 2, num_class)
 64 
 65     def forward(self, x):
 66         """
 67 
 68         :param x:
 69         :return:
 70         """
 71         x = F.relu(self.conv1(x))
 72 
 73         # [b, 64, h, w] => [b, 1024, h, w]
 74         x = self.blk1(x)
 75         x = self.blk2(x)
 76         x = self.blk3(x)
 77         x = self.blk4(x)
 78 
 79         print('after conv:', x.shape)  # [b, 512, 2, 2]
 80         # [b, 512, h, w] => [b, 512, 1, 1]
 81         # x = F.adaptive_avg_pool2d(x, [1, 1])
 82         # print('after pool:', x.shape)
 83         x = x.view(x.size(0), -1)
 84         x = self.outlayer(x)
 85 
 86         return x
 87 
 88 
 89 def main():
 90     blk = ResBlk(64, 128)
 91     tmp = torch.randn(2, 64, 224, 224)
 92     out = blk(tmp)
 93     print('block:', out.shape)
 94 
 95     model = ResNet18(5)
 96     x = torch.randn(2, 3, 64, 64)
 97     out = model(x)
 98     print('resnet:', out.shape)
 99 
100     p = sum(map(lambda p: p.numel(), model.parameters()))
101     print("parameters size :", p)
102 
103 
104 if __name__ == '__main__':
105     main()

3、Train & Test

  

  1 import torch
  2 import os, glob
  3 import random, csv
  4 from torch.utils.data import Dataset, DataLoader  # Dataloader:实现batch加载数据
  5 from torchvision import transforms
  6 from PIL import Image
  7 import visdom
  8 import time
  9 
 10 
 11 class Pokemon(Dataset):  # 继承Datastet类,如自定义模型继承Module类一样
 12     def __init__(self, root, resize, mode):
 13         super(Pokemon, self).__init__()
 14         self.root = root
 15         self.resize = resize
 16 
 17         self.name2lable = {}
 18         for name in sorted(os.listdir(os.path.join(root))):
 19             if not os.path.isdir(os.path.join(root, name)):
 20                 continue
 21             self.name2lable[name] = len(self.name2lable.keys())
 22 
 23         # print(self.name2lable)
 24         # image, label
 25         self.images, self.labels = self.load_csv('images.csv')
 26 
 27         if mode == 'train':  # 60%
 28             self.images = self.images[:int(0.6 * len(self.images))]
 29             self.labels = self.labels[:int(0.6 * len(self.labels))]
 30         elif mode == 'val':  # 20%  60%-80%
 31             self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
 32             self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
 33         else:  # 20% test
 34             self.images = self.images[int(0.8 * len(self.images)):]
 35             self.labels = self.labels[int(0.8 * len(self.labels)):]
 36 
 37     def load_csv(self, filename):
 38         if not os.path.join(self.root, filename):
 39             images = []
 40             for name in self.name2lable.keys():
 41                 images += glob.glob(os.path.join(self.root, name, '*.png'))
 42                 images += glob.glob(os.path.join(self.root, name, '*.jpg'))
 43                 images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
 44 
 45             print(len(images), images)  # pokemon\\bulbasaur\\00000000.png
 46             random.shuffle(images)
 47             with open(os.path.join(self.root, filename), mode='w', newline='') as f:
 48                 writer = csv.writer(f)
 49                 for img in images:
 50                     name = img.split(os.sep)[-2]
 51                     label = self.name2lable[name]
 52                     writer.writerow([img, label, ])
 53                 print("writer into csv file:", filename)
 54 
 55         # read from csv file
 56         images, labels = [], []
 57         with open(os.path.join(self.root, filename)) as f:
 58             reader = csv.reader(f)
 59             for row in reader:  # pokemon\mewtwo\00000005.png,2
 60                 img, label = row  # img为第一列, label为第二列
 61                 label = int(label)
 62 
 63                 images.append(img)
 64                 labels.append(label)
 65         assert len(images) == len(labels)
 66         return images, labels
 67 
 68     def __len__(self):  # 返回元素的个数(数据集的个数)。
 69         return len(self.images)
 70 
 71     def denormalize(self, x_hat):  # 可视化的时候需要还原图片形状
 72         mean = [0.485, 0.456, 0.406]
 73         std = [0.229, 0.224, 0.225]
 74 
 75         # x_hat = (x-mean) / std
 76         # x = x_hat*std = mean
 77         # x: [c, h ,w]
 78         # mean: [3] => [3, 1, 1]
 79         mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
 80         std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
 81         x = x_hat * std + mean
 82         return x
 83 
 84     def __getitem__(self, idx):  # 类的实例对象(p),可以像p[key]取值,当实例对象做p[key]运算时,会调用__getitem__()方法。
 85         # self.images self.labels idx:[0-len(images)]
 86         img, label = self.images[idx], self.labels[idx]  # 返回当前具体数据
 87 
 88         transf = transforms.Compose([
 89             lambda x: Image.open(x).convert('RGB'),  # string path ---> image data
 90             transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
 91             transforms.RandomRotation(15),  # 旋转15度
 92             transforms.CenterCrop(self.resize),  #
 93             transforms.ToTensor(),
 94             transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 会影响图片的可视化效果,需要做denormalize
 95                                  std=[0.229, 0.224, 0.225])
 96         ])
 97 
 98         img = transf(img)
 99         label = torch.tensor(label)
100         return img, label
101 
102 
def main():
    """Preview the training split in visdom, one shuffled batch every 10 seconds."""
    viz = visdom.Visdom()
    dataset = Pokemon('pokemon', 224, 'train')

    # Fetch the first sample to check shapes, e.g.
    # samples: torch.Size([3, 224, 224]) torch.Size([]) tensor(2)
    first_img, first_label = next(iter(dataset))
    print('samples:', first_img.shape, first_label.shape, first_label)

    # shuffle=True: batches are drawn in random order.
    batches = DataLoader(dataset, batch_size=32, shuffle=True)

    for imgs, labels in batches:
        restored = dataset.denormalize(imgs)
        viz.images(restored, nrow=8, win='batch', opts=dict(title='batch'))  # 8 images per row
        viz.text(str(labels.numpy()), win='label', opts=dict(title='batch-y'))

        time.sleep(10)
118 
119 
120 if __name__ == '__main__':
121     main()
wupiao.py
  1 import torch
  2 from torch import nn
  3 from torch.nn import functional as F
  4 
  5 
  6 class ResBlk(nn.Module):
  7     """
  8     resnet block
  9     """
 10 
 11     def __init__(self, ch_in, ch_out, stride=1):
 12         super(ResBlk, self).__init__()
 13 
 14         # we add stride support for resbok, which is distinct from tutorials.
 15         self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
 16         self.bn1 = nn.BatchNorm2d(ch_out)
 17         self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
 18         self.bn2 = nn.BatchNorm2d(ch_out)
 19 
 20         self.extra = nn.Sequential()
 21         if ch_out != ch_in:
 22             # [b, ch_in, h, w] => [b, ch_out, h, w]
 23             self.extra = nn.Sequential(
 24                 nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),  # 1x1卷积核的作用
 25                 nn.BatchNorm2d(ch_out)
 26             )
 27 
 28     def forward(self, x):
 29         """
 30         :param x: [b, ch, h, w]
 31         :return:
 32         """
 33         out = F.relu(self.bn1(self.conv1(x)))
 34         out = self.bn2(self.conv2(out))
 35         # short cut.
 36         # extra module: [b, ch_in, h, w] => [b, ch_out, h, w]
 37         # element-wise add:
 38         out = self.extra(x) + out
 39         out = F.relu(out)
 40 
 41         return out
 42 
 43 
 44 class ResNet18(nn.Module):
 45 
 46     def __init__(self, num_class):
 47         super(ResNet18, self).__init__()
 48 
 49         self.conv1 = nn.Sequential(
 50             nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
 51             nn.BatchNorm2d(16)
 52         )
 53         # followed 4 blocks
 54         # [b, 64, h, w] => [b, 128, h ,w]
 55         self.blk1 = ResBlk(16, 32, stride=3)
 56         # [b, 128, h, w] => [b, 256, h, w]
 57         self.blk2 = ResBlk(32, 64, stride=3)
 58         # # [b, 256, h, w] => [b, 512, h, w]
 59         self.blk3 = ResBlk(64, 128, stride=2)
 60         # # [b, 512, h, w] => [b, 1024, h, w]
 61         self.blk4 = ResBlk(128, 256, stride=2)
 62 
 63         self.outlayer = nn.Linear(256 * 3 * 3, num_class)
 64 
 65     def forward(self, x):
 66         """
 67 
 68         :param x:
 69         :return:
 70         """
 71         x = F.relu(self.conv1(x))
 72 
 73         # [b, 64, h, w] => [b, 1024, h, w]
 74         x = self.blk1(x)
 75         x = self.blk2(x)
 76         x = self.blk3(x)
 77         x = self.blk4(x)
 78 
 79         # print('after conv:', x.shape)  # [b, 512, 2, 2]
 80         # [b, 512, h, w] => [b, 512, 1, 1]
 81         # x = F.adaptive_avg_pool2d(x, [1, 1])
 82         # print('after pool:', x.shape)
 83         x = x.view(x.size(0), -1)
 84         x = self.outlayer(x)
 85 
 86         return x
 87 
 88 
 89 def main():
 90     blk = ResBlk(64, 128)
 91     tmp = torch.randn(2, 64, 224, 224)
 92     out = blk(tmp)
 93     print('block:', out.shape)
 94 
 95     model = ResNet18(5)
 96     x = torch.randn(2, 3, 64, 64)
 97     out = model(x)
 98     print('resnet:', out.shape)  # resnet: torch.Size([2, 5]) batch:2
 99 
100     p = sum(map(lambda p: p.numel(), model.parameters()))
101     print("parameters size :", p)
102 
103 
104 if __name__ == '__main__':
105     main()
Mymodel.py
 1 import torch
 2 from torch import optim, nn
 3 import visdom
 4 import torchvision
 5 from torch.utils.data import Dataset, DataLoader
 6 from wupiao import Pokemon
 7 from Mymodel import ResNet18
 8 
 9 batchsz = 32
10 lr = 1e-3
11 epochs = 10
12 torch.manual_seed(1234)  # 保证实验能够复现出来
13 
14 train_db = Pokemon('pokemon', 224, mode='train')
15 val_db = Pokemon('pokemon', 224, mode='val')
16 test_db = Pokemon('pokemon', 224, mode='test')
17 
18 train_loder = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
19 val_loder = DataLoader(val_db, batch_size=batchsz, num_workers=2)
20 test_loder = DataLoader(test_db, batch_size=batchsz, num_workers=2)
21 
22 viz = visdom.Visdom()
23 
def evalute(model, loder):
    """Return the classification accuracy of ``model`` over ``loder``.

    Shared by validation and testing. The model is switched to eval mode for
    the duration of the call (so BatchNorm uses its running statistics) and
    its previous train/eval mode is restored afterwards.

    :param model: ``nn.Module`` mapping a batch ``x`` to logits ``[b, classes]``.
    :param loder: DataLoader yielding ``(x, y)`` batches; ``loder.dataset``
        must support ``len()``.
    :return: fraction of correctly classified samples (0.0 for an empty dataset).
    """
    total = len(loder.dataset)
    if total == 0:
        return 0.0

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():  # forward pass only, no gradients needed
            for x, y in loder:
                pred = model(x).argmax(dim=1)
                # Fix: accumulate over all batches; the original overwrote
                # `correct`, so only the last batch was counted.
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)

    return correct / total
34 
35 
36 
def main():
    """Train ``ResNet18`` from scratch, keep the best checkpoint, report test accuracy.

    Training loss is plotted every step; validation accuracy is computed and
    plotted every second epoch. The weights with the best validation accuracy
    are saved to ``best.mdl`` and reloaded before the final test.
    """
    model = ResNet18(5)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    saved_ckpt = False
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        for x, y in train_loder:
            # x: [b, 3, 224, 224], y: [b]
            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Fix: test the loop variable `epoch`; the original tested the
        # constant `epochs`, so the "every second epoch" check was a no-op.
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loder)
            # Plot every validation result, not only the improvements.
            viz.line([val_acc], [global_step], win='val_acc', update='append')
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch

                torch.save(model.state_dict(), 'best.mdl')
                saved_ckpt = True

    print('best acc:', best_acc, 'best epoch:', best_epoch)
    # Only reload when this run actually wrote a checkpoint; otherwise the
    # file may be missing or left over from an earlier run.
    if saved_ckpt:
        model.load_state_dict(torch.load('best.mdl'))
        print('loaded from ckpt!')

    test_acc = evalute(model, test_loder)  # test with the best model
    print('test acc:', test_acc)
74 
75 
76 
77 if __name__ == '__main__':
78     main()
train_scratch.py

4、迁移学习(Transfer learning)

   

    

 1 import torch
 2 from torch import nn
 3 from matplotlib import pyplot as plt
 4 
 5 
 6 class Flatten(nn.Module):
 7     def __init__(self):
 8         super(Flatten, self).__init__()
 9 
10     def forward(self, x):
11         shape = torch.prod(torch.tensor(x.shape[1:])).item()
12         return x.view(-1, shape)
13 
14 
def plot_image(img, label, name):
    """Show the first six images of a batch in a 2x3 grid, titled with their labels.

    :param img: batch indexed as ``img[i][0]`` (first channel of sample ``i``);
        presumably MNIST-normalised, since it is de-normalised with the MNIST
        mean/std below (confirm against the caller).
    :param label: per-sample labels; ``label[i].item()`` is shown in the title.
    :param name: text prefix for each subplot title.
    """
    plt.figure()
    for i in range(6):  # six sample images
        # Fix: a 2x2 grid only has 4 slots, so subplot index 5 raised
        # ValueError; six images need a 2x3 grid.
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        # Undo normalisation with the MNIST std/mean (0.3081, 0.1307).
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{} : {}".format(name, label[i].item()))
        # Fix: pass an empty list to hide the ticks; calling xticks()/yticks()
        # without arguments only reads the current ticks.
        plt.xticks([])
        plt.yticks([])
    plt.show()
utils.py
 1 import torch
 2 from torch import optim, nn
 3 import visdom
 4 import torchvision
 5 from torch.utils.data import Dataset, DataLoader
 6 from wupiao import Pokemon
 7 # from Mymodel import ResNet18
 8 from  torchvision.models import resnet18
 9 from utils import Flatten
# Hyper-parameters.
batchsz = 32
lr = 1e-3
epochs = 10
torch.manual_seed(1234)  # fixed seed so the experiment can be reproduced

# One dataset per split, all resized to 224x224.
train_db, val_db, test_db = (
    Pokemon('pokemon', 224, mode=split) for split in ('train', 'val', 'test')
)

# Only the training loader shuffles.
train_loder = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loder = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loder = DataLoader(test_db, batch_size=batchsz, num_workers=2)

viz = visdom.Visdom()
24 
def evalute(model, loder):
    """Return the classification accuracy of ``model`` over ``loder``.

    Shared by validation and testing. The model is switched to eval mode for
    the duration of the call (so BatchNorm uses its running statistics) and
    its previous train/eval mode is restored afterwards.

    :param model: ``nn.Module`` mapping a batch ``x`` to logits ``[b, classes]``.
    :param loder: DataLoader yielding ``(x, y)`` batches; ``loder.dataset``
        must support ``len()``.
    :return: fraction of correctly classified samples (0.0 for an empty dataset).
    """
    total = len(loder.dataset)
    if total == 0:
        return 0.0

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():  # forward pass only, no gradients needed
            for x, y in loder:
                pred = model(x).argmax(dim=1)
                # Fix: accumulate over all batches; the original overwrote
                # `correct`, so only the last batch was counted.
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)

    return correct / total
35 
36 
37 
def main():
    """Fine-tune a pretrained torchvision ResNet-18 on the 5-class dataset.

    The pretrained network's final layer is dropped and replaced by
    ``Flatten`` plus a new 5-way linear head. Training loss is plotted every
    step; validation accuracy is computed and plotted every second epoch. The
    weights with the best validation accuracy are saved to ``best.mdl`` and
    reloaded before the final test.
    """
    # NOTE(review): newer torchvision releases prefer the `weights=` argument
    # over `pretrained=True` - confirm against the installed version.
    train_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(train_model.children())[:-1],  # all layers but the last
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)
                          )

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    saved_ckpt = False
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        for x, y in train_loder:
            # x: [b, 3, 224, 224], y: [b]
            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Fix: test the loop variable `epoch`; the original tested the
        # constant `epochs`, so the "every second epoch" check was a no-op.
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loder)
            # Plot every validation result, not only the improvements.
            viz.line([val_acc], [global_step], win='val_acc', update='append')
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch

                torch.save(model.state_dict(), 'best.mdl')
                saved_ckpt = True

    print('best acc:', best_acc, 'best epoch:', best_epoch)
    # Only reload when this run actually wrote a checkpoint; otherwise the
    # file may be missing or left over from an earlier run.
    if saved_ckpt:
        model.load_state_dict(torch.load('best.mdl'))
        print('loaded from ckpt!')

    test_acc = evalute(model, test_loder)  # test with the best model
    print('test acc:', test_acc)
83 
84 
85 
86 if __name__ == '__main__':
87     main()
train_transfer.py

 

posted @ 2020-08-03 09:20  小吴的日常  阅读(308)  评论(0)    收藏  举报