深度学习(PYTORCH)-3.sphereface-pytorch.lfw_eval.py详解
Posted on 2018-03-07 16:29 LOMOoO 阅读(2581) 评论(0) 编辑 收藏 举报pytorch版本sphereface的原作者地址:https://github.com/clcarwin/sphereface_pytorch
由于接触深度学习不久,所以花了较长时间来阅读源码,以下对项目中的lfw_eval.py文件做了详细解释
(不知是版本问题还是作者code有误,原代码存在很多的bug,需要自行一一纠正,另:由于在windows下运行,故而去掉了gpu加速以及多线程)
1 #-*- coding:utf-8 -*- 2 from __future__ import print_function 3 4 import torch 5 import torch.nn as nn 6 import torch.optim as optim 7 import torch.nn.functional as F 8 from torch.autograd import Variable 9 torch.backends.cudnn.bencmark = True 10 11 import os,sys,cv2,random,datetime 12 import argparse 13 import numpy as np 14 import zipfile 15 16 from dataset import ImageDataset 17 from matlab_cp2tform import get_similarity_transform_for_cv2 18 import net_sphere 19 from matplotlib import pyplot as plt 20 21 #图像对齐和裁剪 22 def alignment(src_img,src_pts): 23 #使用标准人脸坐标对图像进行仿射 24 ref_pts = [ [30.2946, 51.6963],[65.5318, 51.5014], 25 [48.0252, 71.7366],[33.5493, 92.3655],[62.7299, 92.2041] ] 26 crop_size = (96, 112) 27 src_pts = np.array(src_pts).reshape(5,2) 28 29 s = np.array(src_pts).astype(np.float32) 30 r = np.array(ref_pts).astype(np.float32) 31 32 tfm = get_similarity_transform_for_cv2(s, r) 33 face_img = cv2.warpAffine(src_img, tfm, crop_size) 34 return face_img 35 36 #k-fold cross validation(k-折叠交叉验证) 37 #将n份数据分为n_folds份,以次将第i份作为测试集,其余部分作为训练集 38 def KFold(n=200, n_folds=10, shuffle=False): 39 folds = [] 40 base = list(range(n)) 41 for i in range(n_folds): 42 test = base[(i*n//n_folds):((i+1)*n//n_folds)] 43 train = list(set(base)-set(test)) 44 folds.append([train,test]) 45 return folds 46 47 #求解当前阈值时的准确率 48 def eval_acc(threshold, diff): 49 y_true = [] 50 y_predict = [] 51 for d in diff: 52 same = 1 if float(d[2]) > threshold else 0 53 y_predict.append(same) 54 y_true.append(int(d[3])) 55 y_true = np.array(y_true) 56 y_predict = np.array(y_predict) 57 accuracy = 1.0*np.count_nonzero(y_true==y_predict)/len(y_true) 58 return accuracy 59 60 #eval_acc和find_best_threshold共同工作,来求试图找到最佳阈值, 61 # 62 def find_best_threshold(thresholds, predicts): 63 #threshould 阈值 64 best_threshold = best_acc = 0 65 for threshold in thresholds: 66 accuracy = eval_acc(threshold, predicts) 67 if accuracy >= best_acc: 68 best_acc = accuracy 69 best_threshold = threshold 70 return best_threshold 71 72 73 #命令行参数 74 parser = argparse.ArgumentParser(description='PyTorch sphereface lfw') 75 parser.add_argument('--net','-n', default='sphere20a', type=str) 76 parser.add_argument('--lfw', default='../DataSet/lfw.zip', type=str) 77 parser.add_argument('--model','-m', default='./sphere20a_20171020.pth', type=str) 78 args = parser.parse_args() 79 80 predicts=[] 81 82 #加载网络 83 net = getattr(net_sphere,args.net)() 84 #加载模型 85 net.load_state_dict(torch.load(args.model)) 86 # 87 net.eval() 88 # 89 net.feature = True 90 91 #加载图片数据 92 zfile = zipfile.ZipFile(args.lfw) 93 94 #加载landmark,每张照片包括五个特征点,共五组坐标 95 landmark = {} 96 with open('data/lfw_landmark.txt') as f: 97 landmark_lines = f.readlines() 98 #对每一行进行处理 99 for line in landmark_lines: 100 l = line.replace('\n','').split('\t') 101 #将每一组数据转化为字典形式 102 landmark[l[0]] = [int(k) for k in l[1:]] 103 104 #加载pairs 105 with open('data/pairs.txt') as f: 106 pairs_lines = f.readlines()[1:] 107 108 #range表示测试的图片对数 109 for i in range(600): 110 print(str(i)+" start") 111 p = pairs_lines[i].replace('\n','').split('\t') 112 # pairs.txt一共有6000行,存在两种形式, 113 # 分别表示进行对比的两张照片,形式1是同一个人,形式2是不同人: 114 # name 数字1 数字2 115 # name 数字1 name数字2 116 if 3==len(p): 117 sameflag = 1 118 #形式例如:Woody_Allen/Woody_Allen_0002.jpg 119 name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1])) 120 name2 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[2])) 121 if 4==len(p): 122 sameflag = 0 123 name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1])) 124 name2 = p[2]+'/'+p[2]+'_'+'{:04}.jpg'.format(int(p[3])) 125 126 #分别加载两张照片,并对其进行图像对齐 127 org_img1=cv2.imdecode(np.frombuffer(zfile.read("lfw/lfw/"+name1),np.uint8),1) 128 org_img2=cv2.imdecode(np.frombuffer(zfile.read("lfw/lfw/"+name2),np.uint8),1) 129 img1 = alignment(org_img1,landmark[name1]) 130 img2 = alignment(org_img2,landmark[name2]) 131 #1.对输出图像使用cv2进行展示 132 # cv2.imshow("org_img1", org_img1) 133 # cv2.imshow("org_img2", org_img2) 134 # cv2.imshow("img1",img1) 135 # cv2.imshow("img2", img2) 136 # cv2.waitKey(0) 137 # cv2.destroyAllWindows() 138 #2.对输出图像使用matplotlib进行展示 139 fig_new=plt.figure() 140 img_list=[[org_img1,221],[org_img2,222],[img1,223],[img2,224]] 141 for p,q in img_list: 142 ax=fig_new.add_subplot(q) 143 p = p[:, :, (2, 1, 0)] 144 ax.imshow(p) 145 plt.show() 146 147 #cv.flip图像翻转,第二个参数:1:水平翻转,0:垂直翻转,-1:水平垂直翻转 148 imglist = [img1,cv2.flip(img1,1),img2,cv2.flip(img2,1)] 149 #分别对图片进行 150 for m in range(len(imglist)): 151 imglist[m] = imglist[m].transpose(2, 0, 1).reshape((1,3,112,96)) 152 imglist[m] = (imglist[m]-127.5)/128.0 153 154 # p.vstack: 垂直(按照行顺序)的把数组给堆叠起来 155 #******举例****** 156 # import numpy as np 157 # a = [1, 2, 3] 158 # b = [4, 5, 6] 159 # print(np.vstack((a, b))) 160 # 161 # 输出: 162 # [[1 2 3] 163 # [4 5 6]] 164 img = np.vstack(imglist) 165 #将numpy形式转化为variable形式 166 img = Variable(torch.from_numpy(img).float(),volatile=True) 167 output = net(img) 168 #得到计算结果,f1和f2均为512维向量形式 169 f = output.data 170 f1,f2 = f[0],f[2] 171 #计算二者的余弦相似度,后面加上常量是为了防止分母为0 172 #关于余弦相似度请自行百度或google 173 #这里给出一个简单说明的链接:http://blog.csdn.net/huangfei711/article/details/78469614 174 #a*b/|a||b| 175 cosdistance = f1.dot(f2)/(f1.norm()*f2.norm()+1e-5) 176 predicts.append('{}\t{}\t{}\t{}\n'.format(name1,name2,cosdistance,sameflag)) 177 print(str(i) + " end") 178 179 180 #准确率 181 accuracy = [] 182 #(最佳)阈值 183 thd = [] 184 #k-fold cross validation(k-折叠交叉验证) 185 #folds的形式为[[train,test],[train,test].....] 186 folds = KFold(n=600, n_folds=10, shuffle=False) 187 #取数组为-1到1,步长为0.005 188 thresholds = np.arange(-1.0, 1.0, 0.005) 189 # 此处为原作者code,疑似有误,已做修改 190 # predicts = np.array(map(lambda line:frd.append(line.strip('\n').split()), predicts)) 191 predicts = np.array([k.strip('\n').split() for k in predicts]) 192 for idx, (train, test) in enumerate(folds): 193 # predicts[train/test]形式为: 194 # [['Doris_Roberts/Doris_Roberts_0001.jpg' 195 # 'Doris_Roberts/Doris_Roberts_0003.jpg' '0.6532696413605743' '1'],.....] 196 #寻找最佳阈值 197 best_thresh = find_best_threshold(thresholds, predicts[train]) 198 #通过上面的得到的最佳阈值来对test数据集进行测试得到准确率 199 accuracy.append(eval_acc(best_thresh, predicts[test])) 200 #thd阈值 201 thd.append(best_thresh) 202 #np.mean:计算均值,np.std:计算标准差 203 #输出结果分别为:准确率均值,准确率标准差,阈值均值 204 print('LFWACC={:.4f} std={:.4f} thd={:.4f}'.format(np.mean(accuracy), np.std(accuracy), np.mean(thd))) 205 #例如结果为 LFWACC=0.9800 std=0.0600 thd=0.3490 206 #则说明准确率为98%,准确率标准差为0.06,阈值的均值为0.3490 207 #因此我们可以认为余弦相似度大于0.3490的两张图片里是同一个人