pytorch-day09(自定义数据集 & 迁移学习)
1、自定义数据集
"""
Skeleton of a custom dataset: the three hooks a Dataset subclass provides.
"""
import torch
from torch.utils.data import Dataset


class Pokemon(Dataset):
    """Empty template; subclass Dataset the same way a model subclasses nn.Module."""

    def __init__(self):
        super().__init__()

    def __len__(self):
        # Placeholder: a real dataset returns its number of samples here.
        return None

    def __getitem__(self, idx):
        # Placeholder: a real dataset returns the sample stored at position idx.
        return None
例如:
import torch
from torch.utils.data import Dataset


class NumbersDataset(Dataset):
    """Toy dataset of consecutive integers.

    With ``training=True`` the samples are 1..1000; otherwise 1001..1500.
    """

    def __init__(self, training=True):
        first, stop = (1, 1001) if training else (1001, 1501)
        self.samples = [value for value in range(first, stop)]

    def __len__(self):
        """Return the number of samples held by the dataset."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample at position idx.

        Called for ``dataset[idx]``; valid positions run from 0 to
        ``len(self.samples) - 1``.
        """
        sample = self.samples[idx]
        return sample
数据预处理

pokemon数据集:
import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader  # DataLoader: loads samples in batches
from torchvision import transforms
from PIL import Image
import visdom
import time


class Pokemon(Dataset):
    """Image-classification dataset read from ``root/<class name>/<image file>``.

    Args:
        root: directory containing one sub-directory per class.
        resize: side length (pixels) of the square image returned per sample.
        mode: 'train' selects the first 60% of the index, 'val' the next 20%,
            and any other value the final 20% (test).
    """

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize

        # Class name -> integer label. Names are sorted so the mapping is the
        # same on every run.
        self.name2lable = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2lable[name] = len(self.name2lable.keys())

        print(self.name2lable)
        # image paths, integer labels
        self.images, self.labels = self.load_csv('images.csv')

        total = len(self.images)
        if mode == 'train':  # first 60%
            start, stop = 0, int(0.6 * total)
        elif mode == 'val':  # 60% - 80%
            start, stop = int(0.6 * total), int(0.8 * total)
        else:  # last 20%: test
            start, stop = int(0.8 * total), total
        self.images = self.images[start:stop]
        self.labels = self.labels[start:stop]

    def load_csv(self, filename):
        """Return ``(images, labels)`` from ``root/filename``, creating the file if missing.

        The index is shuffled once when it is created and then reused, so every
        Pokemon instance (train / val / test) slices the same ordering.
        """
        csv_path = os.path.join(self.root, filename)

        # The path string itself is always truthy, so the file's existence has
        # to be tested explicitly; otherwise the index is never generated.
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2lable.keys():
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))

            print(len(images), images)  # e.g. pokemon\\bulbasaur\\00000000.png
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = img.split(os.sep)[-2]
                    label = self.name2lable[name]
                    writer.writerow([img, label, ])
                print("writer into csv file:", filename)

        # read from csv file
        images, labels = [], []
        with open(csv_path, newline='') as f:
            reader = csv.reader(f)
            for row in reader:  # e.g. pokemon\mewtwo\00000005.png,2
                img, label = row  # column 0: image path, column 1: label
                label = int(label)

                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo the Normalize step of __getitem__ so images can be displayed.

        Args:
            x_hat: normalized image tensor whose last three dimensions are
                [3, h, w] (a single image or a batch).

        Returns:
            Tensor of the same shape with ``x = x_hat * std + mean`` applied
            per channel.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # x_hat = (x - mean) / std   =>   x = x_hat * std + mean
        # mean / std: [3] => [3, 1, 1] so they broadcast over h and w.
        mean = torch.tensor(mean, device=x_hat.device).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std, device=x_hat.device).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        """Load, augment and normalize sample idx; return ``(image, label)`` tensors."""
        # idx ranges over 0 .. len(self.images) - 1
        img, label = self.images[idx], self.labels[idx]

        # NOTE(review): RandomRotation is applied for every mode, including
        # 'val' and 'test' -- confirm whether evaluation should be deterministic.
        transf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # path string -> image data
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),  # random rotation of up to 15 degrees
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            # Normalization distorts displayed colours; use denormalize() before plotting.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        img = transf(img)
        label = torch.tensor(label)
        return img, label


def main():
    """Print one sample's shape, then stream training batches to visdom."""
    viz = visdom.Visdom()
    pokemon = Pokemon('pokemon', 224, 'train')

    x, y = next(iter(pokemon))
    # e.g. samples: torch.Size([3, 224, 224]) torch.Size([]) tensor(2)
    print('samples:', x.shape, y.shape, y)
    # To show this single image:
    # viz.images(pokemon.denormalize(x), win='samples_x', opts=dict(title='sample_xx'))
    loder = DataLoader(pokemon, batch_size=32, shuffle=True)  # shuffle: batches are drawn in random order

    for x, y in loder:
        viz.images(pokemon.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))  # nrow: 8 images per row
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        time.sleep(10)


if __name__ == '__main__':
    main()
使用API完成:
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
import visdom
import time


def main():
    """Load the 'pokemon' folder with torchvision's ImageFolder and stream batches to visdom."""
    board = visdom.Visdom()
    pipeline = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    dataset = torchvision.datasets.ImageFolder(root='pokemon', transform=pipeline)
    batches = DataLoader(dataset, batch_size=32, shuffle=True)  # DataLoader: loads samples in batches
    # Print the class-name -> index encoding, e.g.
    # {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
    print(dataset.class_to_idx)

    for images, targets in batches:
        board.images(images, nrow=8, win='batch', opts=dict(title='batch'))  # nrow: 8 images per row
        board.text(str(targets.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)


if __name__ == '__main__':
    main()
2、创建模型
Inherit from the base class (nn.Module); define the forward graph in forward().
import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """
    ResNet block: two 3x3 conv + batch-norm layers plus a shortcut connection.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        stride: stride of the first convolution (spatial down-sampling factor).
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()

        # Stride support is added to the block, unlike the basic tutorial version.
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default. A projection is needed whenever the main
        # path changes the tensor shape: a different channel count OR a stride
        # other than 1 (otherwise the element-wise add below fails).
        self.extra = nn.Sequential()
        if stride != 1 or ch_out != ch_in:
            # [b, ch_in, h, w] => [b, ch_out, h', w']
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),  # 1x1 conv matches channels and stride
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h', w']
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # shortcut: project x to the main path's shape, then add element-wise
        out = self.extra(x) + out
        out = F.relu(out)

        return out


class ResNet18(nn.Module):
    """Small ResNet-style classifier: a stem conv, four ResBlk stages, one linear head.

    Args:
        num_class: number of output classes (size of the logits vector).
    """

    def __init__(self, num_class):
        super(ResNet18, self).__init__()

        # [b, 3, h, w] => [b, 16, h', w']
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16)
        )
        # followed by 4 blocks, each halving the spatial size
        # [b, 16, h, w] => [b, 32, h', w']
        self.blk1 = ResBlk(16, 32, stride=2)
        # [b, 32, h, w] => [b, 64, h', w']
        self.blk2 = ResBlk(32, 64, stride=2)
        # [b, 64, h, w] => [b, 128, h', w']
        self.blk3 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h', w']
        self.blk4 = ResBlk(128, 256, stride=2)

        self.outlayer = nn.Linear(256 * 2 * 2, num_class)

    def forward(self, x):
        """
        :param x: [b, 3, h, w]
        :return: logits, [b, num_class]
        """
        x = F.relu(self.conv1(x))

        # [b, 16, h, w] => [b, 256, h', w']
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Bring the feature map to the 2x2 size the linear head expects.
        # For a 64x64 input the map is already 2x2, so this is an identity;
        # for other input sizes it keeps the flattened width at 256 * 2 * 2.
        x = F.adaptive_avg_pool2d(x, [2, 2])
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x


def main():
    """Smoke test: print output shapes of one block and of the full network."""
    blk = ResBlk(64, 128)
    tmp = torch.randn(2, 64, 224, 224)
    out = blk(tmp)
    print('block:', out.shape)

    model = ResNet18(5)
    x = torch.randn(2, 3, 64, 64)
    out = model(x)
    print('resnet:', out.shape)

    p = sum(map(lambda p: p.numel(), model.parameters()))
    print("parameters size :", p)


if __name__ == '__main__':
    main()
3、Train & Test

import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader  # DataLoader: loads samples in batches
from torchvision import transforms
from PIL import Image
import visdom
import time


class Pokemon(Dataset):
    """Image-classification dataset read from ``root/<class name>/<image file>``.

    Args:
        root: directory containing one sub-directory per class.
        resize: side length (pixels) of the square image returned per sample.
        mode: 'train' selects the first 60% of the index, 'val' the next 20%,
            and any other value the final 20% (test).
    """

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize

        # Class name -> integer label. Names are sorted so the mapping is the
        # same on every run.
        self.name2lable = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2lable[name] = len(self.name2lable.keys())

        # image paths, integer labels
        self.images, self.labels = self.load_csv('images.csv')

        total = len(self.images)
        if mode == 'train':  # first 60%
            start, stop = 0, int(0.6 * total)
        elif mode == 'val':  # 60% - 80%
            start, stop = int(0.6 * total), int(0.8 * total)
        else:  # last 20%: test
            start, stop = int(0.8 * total), total
        self.images = self.images[start:stop]
        self.labels = self.labels[start:stop]

    def load_csv(self, filename):
        """Return ``(images, labels)`` from ``root/filename``, creating the file if missing.

        The index is shuffled once when it is created and then reused, so every
        Pokemon instance (train / val / test) slices the same ordering.
        """
        csv_path = os.path.join(self.root, filename)

        # The path string itself is always truthy, so the file's existence has
        # to be tested explicitly; otherwise the index is never generated.
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2lable.keys():
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))

            print(len(images), images)  # e.g. pokemon\\bulbasaur\\00000000.png
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = img.split(os.sep)[-2]
                    label = self.name2lable[name]
                    writer.writerow([img, label, ])
                print("writer into csv file:", filename)

        # read from csv file
        images, labels = [], []
        with open(csv_path, newline='') as f:
            reader = csv.reader(f)
            for row in reader:  # e.g. pokemon\mewtwo\00000005.png,2
                img, label = row  # column 0: image path, column 1: label
                label = int(label)

                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo the Normalize step of __getitem__ so images can be displayed.

        Args:
            x_hat: normalized image tensor whose last three dimensions are
                [3, h, w] (a single image or a batch).

        Returns:
            Tensor of the same shape with ``x = x_hat * std + mean`` applied
            per channel.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # x_hat = (x - mean) / std   =>   x = x_hat * std + mean
        # mean / std: [3] => [3, 1, 1] so they broadcast over h and w.
        mean = torch.tensor(mean, device=x_hat.device).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std, device=x_hat.device).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        """Load, augment and normalize sample idx; return ``(image, label)`` tensors."""
        # idx ranges over 0 .. len(self.images) - 1
        img, label = self.images[idx], self.labels[idx]

        # NOTE(review): RandomRotation is applied for every mode, including
        # 'val' and 'test' -- confirm whether evaluation should be deterministic.
        transf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # path string -> image data
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),  # random rotation of up to 15 degrees
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            # Normalization distorts displayed colours; use denormalize() before plotting.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        img = transf(img)
        label = torch.tensor(label)
        return img, label


def main():
    """Print one sample's shape, then stream training batches to visdom."""
    viz = visdom.Visdom()
    pokemon = Pokemon('pokemon', 224, 'train')

    x, y = next(iter(pokemon))
    # e.g. samples: torch.Size([3, 224, 224]) torch.Size([]) tensor(2)
    print('samples:', x.shape, y.shape, y)
    # To show this single image:
    # viz.images(pokemon.denormalize(x), win='samples_x', opts=dict(title='sample_xx'))
    loder = DataLoader(pokemon, batch_size=32, shuffle=True)  # shuffle: batches are drawn in random order

    for x, y in loder:
        viz.images(pokemon.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))  # nrow: 8 images per row
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        time.sleep(10)


if __name__ == '__main__':
    main()
import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """
    ResNet block: two 3x3 conv + batch-norm layers plus a shortcut connection.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        stride: stride of the first convolution (spatial down-sampling factor).
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()

        # Stride support is added to the block, unlike the basic tutorial version.
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default. A projection is needed whenever the main
        # path changes the tensor shape: a different channel count OR a stride
        # other than 1 (otherwise the element-wise add below fails).
        self.extra = nn.Sequential()
        if stride != 1 or ch_out != ch_in:
            # [b, ch_in, h, w] => [b, ch_out, h', w']
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),  # 1x1 conv matches channels and stride
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h', w']
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # shortcut: project x to the main path's shape, then add element-wise
        out = self.extra(x) + out
        out = F.relu(out)

        return out


class ResNet18(nn.Module):
    """Small ResNet-style classifier: a stem conv, four ResBlk stages, one linear head.

    Args:
        num_class: number of output classes (size of the logits vector).
    """

    def __init__(self, num_class):
        super(ResNet18, self).__init__()

        # [b, 3, h, w] => [b, 16, h', w']
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16)
        )
        # followed by 4 blocks
        # [b, 16, h, w] => [b, 32, h', w']
        self.blk1 = ResBlk(16, 32, stride=3)
        # [b, 32, h, w] => [b, 64, h', w']
        self.blk2 = ResBlk(32, 64, stride=3)
        # [b, 64, h, w] => [b, 128, h', w']
        self.blk3 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h', w']
        self.blk4 = ResBlk(128, 256, stride=2)

        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        """
        :param x: [b, 3, h, w]
        :return: logits, [b, num_class]
        """
        x = F.relu(self.conv1(x))

        # [b, 16, h, w] => [b, 256, h', w']
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Bring the feature map to the 3x3 size the linear head expects.
        # For a 224x224 input the map is already 3x3, so this is an identity;
        # for other sizes (e.g. the 64x64 smoke test in main, which reaches
        # 1x1) it keeps the flattened width at 256 * 3 * 3 instead of crashing.
        x = F.adaptive_avg_pool2d(x, [3, 3])
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x


def main():
    """Smoke test: print output shapes of one block and of the full network."""
    blk = ResBlk(64, 128)
    tmp = torch.randn(2, 64, 224, 224)
    out = blk(tmp)
    print('block:', out.shape)

    model = ResNet18(5)
    x = torch.randn(2, 3, 64, 64)
    out = model(x)
    print('resnet:', out.shape)  # resnet: torch.Size([2, 5])  (batch of 2)

    p = sum(map(lambda p: p.numel(), model.parameters()))
    print("parameters size :", p)


if __name__ == '__main__':
    main()
import torch
from torch import optim, nn
import visdom
import torchvision
from torch.utils.data import Dataset, DataLoader
from wupiao import Pokemon
from Mymodel import ResNet18

batchsz = 32
lr = 1e-3
epochs = 10
torch.manual_seed(1234)  # fixed seed so the experiment can be reproduced

train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')

train_loder = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loder = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loder = DataLoader(test_db, batch_size=batchsz, num_workers=2)

viz = visdom.Visdom()


def evalute(model, loder):
    """Return the classification accuracy of model over every batch of loder.

    Used unchanged for both validation and test. The model is switched to
    eval mode for the pass (so BatchNorm uses its running statistics) and its
    previous mode is restored afterwards.

    Args:
        model: module mapping an image batch to logits of shape [b, classes].
        loder: DataLoader yielding ``(x, y)`` batches.

    Returns:
        Number of correct predictions divided by ``len(loder.dataset)``.
    """
    was_training = model.training
    model.eval()

    correct = 0
    total = len(loder.dataset)
    with torch.no_grad():  # forward pass only
        for x, y in loder:
            logits = model(x)
            pred = logits.argmax(dim=1)
            # Accumulate across batches; assigning here would keep only the
            # last batch's count.
            correct += torch.eq(pred, y).sum().item()

    model.train(was_training)
    return correct / total


def main():
    """Train ResNet18, keep the checkpoint with the best validation accuracy, then test it."""
    model = ResNet18(5)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loder):
            # x: [b, 3, 224, 224]  y: [b]

            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate every second epoch. The test must use the loop variable
        # `epoch`; the constant `epochs` would make the condition fixed.
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loder)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch

                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)
    model.load_state_dict(torch.load('best.mdl'))  # replace the current weights with the best checkpoint
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loder)  # test with the best model
    print('test acc:', test_acc)


if __name__ == '__main__':
    main()
4、迁移学习(Transfer learning)


import torch
from torch import nn


class Flatten(nn.Module):
    """Flatten every dimension after the first: [b, d1, d2, ...] => [b, d1*d2*...]."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """
        :param x: tensor with at least one dimension, [b, ...]
        :return: [b, n] where n is the product of the trailing dimensions
                 (n == 1 when x is 1-D)
        """
        # torch.Size.numel() gives the integer product directly; reshape also
        # accepts non-contiguous input, which view would reject.
        return x.reshape(x.size(0), x.shape[1:].numel())


def plot_image(img, label, name):
    """Show the first 6 images of a batch in a 2x3 grid, each titled "name : label".

    :param img: batch of images; channel 0 of img[i] is displayed in grayscale
    :param label: tensor of labels; label[i].item() is shown in the title
    :param name: text prefix for each title
    """
    # Imported here so that importing Flatten does not require matplotlib.
    from matplotlib import pyplot as plt

    plt.figure()
    for i in range(6):  # the first 6 images
        # 6 panels need a 2x3 grid; a 2x2 grid has no slot for panels 5 and 6.
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0] * 0.308 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{} : {}".format(name, label[i].item()))
        # Passing an empty list hides the ticks; calling with no argument
        # only queries them.
        plt.xticks([])
        plt.yticks([])
    plt.show()
import torch
from torch import optim, nn
import visdom
import torchvision
from torch.utils.data import Dataset, DataLoader
from wupiao import Pokemon
# from Mymodel import ResNet18
from torchvision.models import resnet18
from utils import Flatten

batchsz = 32
lr = 1e-3
epochs = 10
torch.manual_seed(1234)  # fixed seed so the experiment can be reproduced

train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')

train_loder = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loder = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loder = DataLoader(test_db, batch_size=batchsz, num_workers=2)

viz = visdom.Visdom()


def evalute(model, loder):
    """Return the classification accuracy of model over every batch of loder.

    Used unchanged for both validation and test. The model is switched to
    eval mode for the pass (so BatchNorm uses its running statistics) and its
    previous mode is restored afterwards.

    Args:
        model: module mapping an image batch to logits of shape [b, classes].
        loder: DataLoader yielding ``(x, y)`` batches.

    Returns:
        Number of correct predictions divided by ``len(loder.dataset)``.
    """
    was_training = model.training
    model.eval()

    correct = 0
    total = len(loder.dataset)
    with torch.no_grad():  # forward pass only
        for x, y in loder:
            logits = model(x)
            pred = logits.argmax(dim=1)
            # Accumulate across batches; assigning here would keep only the
            # last batch's count.
            correct += torch.eq(pred, y).sum().item()

    model.train(was_training)
    return correct / total


def main():
    """Fine-tune a pretrained resnet18 with a new 5-class head, then test the best checkpoint."""
    # model = ResNet18(5)
    train_model = resnet18(pretrained=True)
    # Keep every pretrained layer except the final one, then attach a new
    # classifier for the 5 classes.
    model = nn.Sequential(*list(train_model.children())[:-1],
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)
                          )
    # x = torch.randn(2, 3, 224, 224)
    # print(model(x).shape)  # torch.Size([2, 512, 1, 1]) => torch.Size([2, 5])

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loder):
            # x: [b, 3, 224, 224]  y: [b]

            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate every second epoch. The test must use the loop variable
        # `epoch`; the constant `epochs` would make the condition fixed.
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loder)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch

                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)
    model.load_state_dict(torch.load('best.mdl'))  # replace the current weights with the best checkpoint
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loder)  # test with the best model
    print('test acc:', test_acc)


if __name__ == '__main__':
    main()

浙公网安备 33010602011771号