修改PyTorch官方示例,使其适用于自己的二分类迁移学习项目
本demo从pytorch官方的迁移学习示例修改而来,增加了以下功能:
- 根据验证集AUC保留最优模型参数;
- 五折交叉验证;
- 输出验证集错误分类图片;
- 输出分类报告并保存AUC结果图片。
import copy
import os
import shutil
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scikitplot as skplt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from sklearn.metrics import roc_auc_score, classification_report
from sklearn.model_selection import KFold
from torch.autograd import Variable  # noqa: F401  (kept for compatibility; unused)
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

plt.switch_backend('agg')
N_CLASSES = 2
BATCH_SIZE = 8
DATA_DIR = './data'
# Class index -> class name used in reports and CSV output.
LABEL_DICT = {0: 'class_1', 1: 'class_2'}


def imshow(inp, title=None):
    """Display a normalized image tensor with matplotlib.

    The tensor is moved to channel-last layout and the normalization applied
    by the data transforms (same mean/std constants) is undone before
    plotting.

    Args:
        inp: Image tensor; it is transposed with axes (1, 2, 0).
        title: Optional title drawn above the image.
    """
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(100)


def _save_fold_outputs(labels, probs, fold, name):
    """Write the per-image score CSV and the ROC figure for one fold."""
    print(labels.shape, probs.shape)
    csv_file = pd.DataFrame(probs, columns=['class_1', 'class_2'])
    csv_file['true_label'] = [LABEL_DICT[int(x)] for x in labels]
    csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
    skplt.metrics.plot_roc_curve(labels, probs, curves=['each_class'])
    plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
    # Release the figure so figures do not accumulate across folds.
    plt.close('all')


def _log_misclassified(records):
    """Print and record every misclassified image; return how many there were.

    ``records`` yields ``((path, actual_label), predicted_label)`` pairs.
    Lines are written to the module-level ``label_file``.
    """
    count = 0
    for (path, actual), pred in records:
        actual_label = int(actual)
        pred_label = int(pred)
        if actual_label != pred_label:
            tmp_word = f'{os.path.basename(path)}, actual: {LABEL_DICT[actual_label]}, ' \
                       f'pred: {LABEL_DICT[pred_label]}'
            print(tmp_word)
            label_file.write(tmp_word + '\n')
            count += 1
    return count


def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
    """Train ``model`` and keep the weights with the best validation AUC.

    Relies on module-level names set up by the ``__main__`` section:
    ``dataloaders``, ``dataset_sizes``, ``class_names``, ``use_gpu``,
    ``report_file`` and ``label_file``.

    Args:
        model: Network to train.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch, after training.
        fold: Fold number, used only in output file names.
        name: Model name, used only in output file names.
        num_epochs: Number of epochs to run.

    Returns:
        ``model`` with the best-scoring weights loaded.
    """
    since = time.time()
    # Snapshot the starting weights; replaced whenever validation AUC improves.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_auc = 0.0
    best_desc = [0, 0, None]
    best_img_name = None
    plt_auc = None
    device = torch.device('cuda' if use_gpu else 'cpu')

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('- ' * 50)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)
            loader = dataloaders[phase]

            pred_batches = []
            label_batches = []
            prob_batches = []
            val_records = []
            running_loss = 0.0
            running_corrects = 0

            if is_train:
                batches = ((batch, None) for batch in loader)
            else:
                # Pair each batch with its dataset indices so file names can
                # be recovered. This only lines up when the loader does NOT
                # shuffle: a shuffling sampler draws a fresh permutation for
                # every iteration, so the two iterators would disagree.
                batches = zip(loader, loader.batch_sampler)
                val_imgs = loader.dataset.imgs

            for (inputs, labels), index in batches:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Reset accumulated gradients.
                optimizer.zero_grad()

                # Only track gradients while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                if not is_train:
                    val_records.extend(val_imgs[i] for i in index)
                    prob_batches.append(outputs.detach().cpu().numpy())

                pred_batches.append(preds.cpu().numpy())
                label_batches.append(labels.cpu().numpy())
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            if is_train:
                # Step the schedule after the optimizer updates of this epoch;
                # stepping first skips the initial learning-rate value.
                scheduler.step()

            phase_pred = np.concatenate(pred_batches)
            phase_label = np.concatenate(label_batches)

            print()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            # NOTE(review): AUC is computed from hard class predictions, not
            # from scores -- confirm this is the intended selection metric.
            epoch_auc = roc_auc_score(phase_label, phase_pred)
            print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc, epoch_auc))
            report = classification_report(phase_label, phase_pred, target_names=class_names)
            print(report)

            # Keep the model whenever validation AUC improves.
            if phase == 'val' and epoch_auc > best_auc:
                best_auc = epoch_auc
                best_desc = epoch_acc, epoch_auc, report
                best_img_name = list(zip(val_records, phase_pred))
                best_model_wts = copy.deepcopy(model.state_dict())
                plt_auc = phase_label, np.concatenate(prob_batches)

        print()

    if plt_auc is not None:
        _save_fold_outputs(plt_auc[0], plt_auc[1], fold, name)
    else:
        print('No epoch improved on the initial AUC; skip the CSV and ROC outputs.')

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    reports = 'The Desc according to the Best val Auc: \nACC -> {:4f}\nAUC -> {:4f}\n\n{}'.format(
        best_desc[0], best_desc[1], best_desc[2])
    report_file.write(reports)
    print(reports)
    print('List the wrong judgement img ...')
    count = _log_misclassified(best_img_name or [])
    print(f'This fold has {count} wrong records ...')

    # Restore the best weights seen during training.
    model.load_state_dict(best_model_wts)
    return model


def plot_img():
    """Show every training batch as an image grid titled with its class names."""
    for i, data in enumerate(dataloaders['train']):
        inputs, classes = data
        out = torchvision.utils.make_grid(inputs)
        imshow(out, title=[class_names[x] for x in classes])


# Adapt the label prefixes below to the file naming of your own project.
def move_file(data, file_path, dir_path, root_path):
    """Copy files into ``<root_path>/<dir_path>/<file_path>/<label>/``.

    A file is assigned to a label when its name starts with that label;
    files matching neither label are skipped. An existing target directory
    is removed first. The working directory is left unchanged.

    Args:
        data: Iterable of file names located in ``<root_path>/<dir_path>``.
        file_path: Name of the split directory to (re)create, e.g. ``train``.
        dir_path: Data directory, relative to ``root_path``.
        root_path: Project root directory.
    """
    label_0 = 'class_2'
    label_1 = 'class_1'
    print(f'start copy the {file_path} file ...')
    source_dir = os.path.abspath(os.path.join(root_path, dir_path))
    target_dir = os.path.join(source_dir, file_path)
    if os.path.exists(target_dir):
        print(f'Find exist {file_path} file, the file will be dropped.')
        shutil.rmtree(target_dir)
        print(f'Finish drop the {file_path} file.')

    os.mkdir(target_dir)
    for d in data:
        for label in (label_0, label_1):
            # Prefix match works for labels of any length.
            if d.startswith(label):
                label_dir = os.path.join(target_dir, label)
                os.makedirs(label_dir, exist_ok=True)
                shutil.copyfile(os.path.join(source_dir, d), os.path.join(label_dir, d))
                break
    print('finish this work ...')


def _build_model(name):
    """Return a pretrained backbone with a fresh ``N_CLASSES``-way head.

    NOTE(review): the head ends in ``nn.Sigmoid`` while training uses
    ``nn.CrossEntropyLoss``, which applies its own log-softmax -- confirm
    this combination is intended.
    """
    if name == 'resnet':
        model = models.resnet50(pretrained=True)
        num_ftrs = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Linear(num_ftrs, N_CLASSES),
            nn.Sigmoid()
        )
        return model
    if name == 'desnet':
        model = models.densenet121(pretrained=True)
        num_ftrs = model.classifier.in_features
        model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, N_CLASSES),
            nn.Sigmoid()
        )
        return model
    # An inception model can be added here.
    if name == 'inception':
        raise Exception("not provide inception model ...")
    raise ValueError(f"unknown model name {name!r}; expected 'resnet' or 'desnet'")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} <resnet|desnet>')
    model_name = sys.argv[1]
    # Fail before any files are copied when the model cannot be built.
    if model_name == 'inception':
        raise Exception("not provide inception model ...")
    if model_name not in ('resnet', 'desnet'):
        raise ValueError(f"unknown model name {model_name!r}; expected 'resnet' or 'desnet'")

    for out_dir in ('roc_img', 'prob_result', 'report', 'error_record', 'model'):
        os.makedirs(out_dir, exist_ok=True)

    kf = KFold(n_splits=5, shuffle=True, random_state=1)
    # All relative paths in this script are resolved against the start directory.
    origin_path = os.getcwd()
    # Sorted so the fold assignment does not depend on directory listing order.
    dd_list = np.array(sorted(o for o in os.listdir(DATA_DIR)
                              if os.path.isfile(os.path.join(DATA_DIR, o))))

    data_transforms = {
        'train': transforms.Compose([
            # Random crop resized to 224x224.
            transforms.RandomResizedCrop(224),
            # Random horizontal flip.
            transforms.RandomHorizontalFlip(),
            # Convert the image to a tensor.
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    use_gpu = torch.cuda.is_available()

    with open(f'./error_record/{model_name}_img_name_actual_pred.txt', 'w',
              encoding='utf-8') as label_file:
        for m, n in enumerate(kf.split(dd_list), start=1):
            with open(f'./report/{model_name}_fold_{m}_report.txt', 'w',
                      encoding='utf-8') as report_file:
                print(f'The {m} fold for copy file and training ...')
                move_file(dd_list[n[0]], 'train', DATA_DIR, origin_path)
                move_file(dd_list[n[1]], 'val', DATA_DIR, origin_path)

                image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
                                                          data_transforms[x])
                                  for x in ['train', 'val']}
                # Only the training loader shuffles: train_model recovers
                # validation file names from the sampler order.
                dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                              batch_size=BATCH_SIZE,
                                                              shuffle=(x == 'train'),
                                                              num_workers=8,
                                                              pin_memory=False)
                               for x in ['train', 'val']}

                dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

                class_names = image_datasets['train'].classes
                print('label mapping: ')
                print(image_datasets['train'].class_to_idx)

                model_ft = _build_model(model_name)
                if use_gpu:
                    model_ft = model_ft.cuda()

                criterion = nn.CrossEntropyLoss()
                optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
                # Multiply the learning rate by 0.1 every 7 epochs.
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
                model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                                       m, model_name, num_epochs=25)
                print('Start save the model ...')
                torch.save(model_ft.state_dict(), f'./model/fold_{m}_{model_name}.pkl')
                print(f'The mission of the fold {m} finished.')
                print('# ' * 50)