(Original) Head pose estimation: HYBRID COARSE-FINE CLASSIFICATION FOR HEAD POSE ESTIMATION

Please credit the source when reposting:

https://www.cnblogs.com/darkknightzh/p/12150128.html

Paper:

HYBRID COARSE-FINE CLASSIFICATION FOR HEAD POSE ESTIMATION

Paper URL:

https://arxiv.org/abs/1901.06778

Official PyTorch code:

https://github.com/haofanwang/accurate-head-pose

 

The paper proposes a hybrid coarse-fine classification scheme for head pose estimation.

 

1. Network architecture

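The network architecture is shown in the figure below (figure not reproduced here). The backbone extracts features from the input image, and these features are fed into several separate fc layers. Each fc layer classifies the angle into bins of a different width (1, 3, 11, 33 and 99 degrees, i.e. 198, 66, 18, 6 and 2 classes) over the range -99° to 99° (198° in total). The outputs are normalized with softmax and split into two branches: one computes the expected angle and the MSE loss between this expectation and the ground truth, the other computes the cross-entropy loss. The losses are then summed to give the final loss.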
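Because every bin width divides the same 198° span, each continuous angle has exactly one class label per head. The sketch below shows one way to compute these five labels. It is an assumption for illustration: the dataset code of the official repository is not reproduced in this post, and the helper `angle_to_bin_labels` (as well as the clamping of out-of-range angles) is hypothetical.

```python
# Hypothetical helper: map a continuous angle (degrees) to one class label
# per bin width used by the hybrid coarse-fine heads.
# Assumes the bins start at -99 degrees and span 198 degrees in total.
LOW = -99
SPAN = 198
BIN_WIDTHS = (1, 3, 11, 33, 99)   # -> 198, 66, 18, 6, 2 classes


def num_classes(width):
    # Each width divides the 198-degree span exactly.
    return SPAN // width


def angle_to_bin_labels(angle):
    """Return one class index per bin width, clamped to the valid range."""
    labels = []
    for width in BIN_WIDTHS:
        idx = int((angle - LOW) // width)
        idx = min(max(idx, 0), num_classes(width) - 1)  # clamp out-of-range angles
        labels.append(idx)
    return labels


if __name__ == "__main__":
    print([num_classes(w) for w in BIN_WIDTHS])   # [198, 66, 18, 6, 2]
    print(angle_to_bin_labels(0.0))               # [99, 33, 9, 3, 1]
```

The coarse labels are therefore nested: the same angle gets a fine 1° label and progressively coarser labels, down to a 2-way split for the 99° bins.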

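1) The MSE loss is close to the one in Deep Head Pose (Hopenet). The difference is that here the expectation is computed from the 198-class output, whereas Deep Head Pose uses 66 classes.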

2) The other bin widths (all except the 198-class one) are used only for the cross-entropy loss (as shown in the figure).

3) The cross-entropy losses of the different bin widths are weighted differently.

4) The MSE loss is given a relatively large weight in this paper (2).

5) During training, probabilities are computed with softmax. At test time, a softmax with temperature is used (since the code sets T=1, it is effectively a plain softmax).

6) According to https://arxiv.org/abs/1503.02531, given an input logit $z_i$, the output $q_i$ of the softmax with temperature is:

$q_i=\frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}$

其中Ttemperature。通常设置为1(即为softmax)。T越大,输出概率的差异性越小;T越小(越接近0),输出概率的差异性越大。

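I therefore think the figure above would be easier to understand if it were redrawn as follows (figure not reproduced here):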

The loss function of the paper is:

$Loss=\alpha \cdot \mathrm{MSE}(y,y^{*})+\sum_{i=1}^{num}\beta_i \cdot H(y_i,y_i^{*})$

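where H denotes the cross-entropy loss, $y$ and $y^{*}$ are the expected angle and the ground-truth angle, and $y_i$ and $y_i^{*}$ are the prediction and the ground-truth bin of the i-th head (num = 5). $\beta_i$ is the cross-entropy weight for each bin width (see the code for the exact values: the training code below uses 7, 5, 3, 1 and 1 for the 198-, 66-, 18-, 6- and 2-class heads, with $\alpha = 2$).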
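Below is a minimal per-sample sketch of this loss for a single angle, written in plain Python so that each term is visible. The weights β = (7, 5, 3, 1, 1) and α = 2 are taken from the training code below. The helper names are hypothetical. The real code works on batches of PyTorch tensors, where `nn.MSELoss` and `nn.CrossEntropyLoss` average over the batch.

```python
import math

WEIGHTS = (7, 5, 3, 1, 1)   # beta_i for the 198/66/18/6/2-class heads (from the training code)
ALPHA = 2.0                 # MSE weight (default --alpha in the training code)


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def cross_entropy(logits, target):
    # -log(softmax(logits)[target]), i.e. what nn.CrossEntropyLoss computes for one sample
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]


def expected_angle(logits_198):
    # Expectation over the 198 one-degree bins, shifted so that bin 0 maps to -99 degrees
    p = softmax(logits_198)
    return sum(i * pi for i, pi in enumerate(p)) - 99


def hybrid_loss(head_logits, head_labels, angle_true):
    """head_logits: five logit lists (198, 66, 18, 6, 2 classes); head_labels: five class indices."""
    mse = (expected_angle(head_logits[0]) - angle_true) ** 2
    ce = sum(w * cross_entropy(z, t) for w, z, t in zip(WEIGHTS, head_logits, head_labels))
    return ALPHA * mse + ce
```

For example, with all-zero logits every head is uniform. The expected angle is then the mean bin index minus 99, that is 98.5 - 99 = -0.5, and each cross-entropy term equals log(number of classes).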

2. Code

2.1 Network architecture

import math

import torch.nn as nn


class Multinet(nn.Module):
    # Hopenet-style network: for each of yaw, pitch and roll there are 5 classification heads
    # with 198 / 66 / 18 / 6 / 2 bins (bin widths of 1 / 3 / 11 / 33 / 99 degrees).
    # Predicts Euler angles by binning and regression with the expected value
    def __init__(self, block, layers, num_bins):
        super(Multinet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)    # up to here: the ResNet backbone
        self.fc_yaw = nn.Linear(512 * block.expansion, num_bins)     # same as Hopenet, but num_bins=198
        self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)   # same as Hopenet, but num_bins=198
        self.fc_roll = nn.Linear(512 * block.expansion, num_bins)    # same as Hopenet, but num_bins=198

        self.fc_yaw_1 = nn.Linear(512 * block.expansion, 66)   # 66 bins, as in Deep Head Pose
        self.fc_yaw_2 = nn.Linear(512 * block.expansion, 18)   # the remaining fc layers are new
        self.fc_yaw_3 = nn.Linear(512 * block.expansion, 6)
        self.fc_yaw_4 = nn.Linear(512 * block.expansion, 2)

        self.fc_pitch_1 = nn.Linear(512 * block.expansion, 66)
        self.fc_pitch_2 = nn.Linear(512 * block.expansion, 18)
        self.fc_pitch_3 = nn.Linear(512 * block.expansion, 6)
        self.fc_pitch_4 = nn.Linear(512 * block.expansion, 2)

        self.fc_roll_1 = nn.Linear(512 * block.expansion, 66)
        self.fc_roll_2 = nn.Linear(512 * block.expansion, 18)
        self.fc_roll_3 = nn.Linear(512 * block.expansion, 6)
        self.fc_roll_4 = nn.Linear(512 * block.expansion, 2)

        # Vestigial layer from previous experiments
        self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)  # unused

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # backbone features
        pre_yaw = self.fc_yaw(x)     # below: logits of every head for yaw, pitch and roll
        pre_pitch = self.fc_pitch(x)
        pre_roll = self.fc_roll(x)

        pre_yaw_1 = self.fc_yaw_1(x)
        pre_pitch_1 = self.fc_pitch_1(x)
        pre_roll_1 = self.fc_roll_1(x)

        pre_yaw_2 = self.fc_yaw_2(x)
        pre_pitch_2 = self.fc_pitch_2(x)
        pre_roll_2 = self.fc_roll_2(x)

        pre_yaw_3 = self.fc_yaw_3(x)
        pre_pitch_3 = self.fc_pitch_3(x)
        pre_roll_3 = self.fc_roll_3(x)

        pre_yaw_4 = self.fc_yaw_4(x)
        pre_pitch_4 = self.fc_pitch_4(x)
        pre_roll_4 = self.fc_roll_4(x)

        return pre_yaw,pre_yaw_1,pre_yaw_2,pre_yaw_3,pre_yaw_4, pre_pitch,pre_pitch_1,pre_pitch_2,pre_pitch_3,pre_pitch_4, pre_roll,pre_roll_1,pre_roll_2,pre_roll_3,pre_roll_4
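To estimate the cost of the extra heads, the arithmetic below counts the fc parameters for the ResNet50 configuration used in the training script. There, `torchvision.models.resnet.Bottleneck` has `expansion = 4`, so the pooled feature is 512 × 4 = 2048-dimensional. The count ignores the unused `fc_finetune` layer, and the variable names are my own.

```python
# Arithmetic sketch: parameters of Multinet's classification heads
# when built with ResNet50 (torchvision Bottleneck, expansion = 4).
FEATURE_DIM = 512 * 4               # 512 * block.expansion = 2048
HEAD_SIZES = (198, 66, 18, 6, 2)    # fc_*, fc_*_1, fc_*_2, fc_*_3, fc_*_4 for one angle


def linear_params(in_features, out_features):
    # nn.Linear has a weight matrix (out x in) plus a bias vector (out)
    return in_features * out_features + out_features


per_angle = sum(linear_params(FEATURE_DIM, n) for n in HEAD_SIZES)
total = 3 * per_angle               # yaw, pitch and roll
print(per_angle, total)             # 594210 1782630
```

So all 15 heads together add roughly 1.8M parameters on top of the shared backbone.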

2.2 Training code

import argparse
import os
import sys

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torchvision
from torch.autograd import Variable
from torchvision import transforms

import datasets  # dataset definitions from the official repository
import hopenet   # model definitions (Multinet) from the official repository


def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=0, type=int)
    parser.add_argument('--num_epochs', dest='num_epochs', help='Maximum number of training epochs.', default=25, type=int)
    parser.add_argument('--batch_size', dest='batch_size', help='Batch size.', default=32, type=int)
    parser.add_argument('--lr', dest='lr', help='Base learning rate.', default=0.000001, type=float)
    parser.add_argument('--dataset', dest='dataset', help='Dataset type.', default='AFLW_multi', type=str)
    parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.', default='', type=str)
    parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.', default='/tools/AFLW_train.txt', type=str)
    parser.add_argument('--output_string', dest='output_string', help='String appended to output snapshots.', default = '', type=str)
    parser.add_argument('--alpha', dest='alpha', help='Regression loss coefficient.', default=2, type=float)
    parser.add_argument('--snapshot', dest='snapshot', help='Path of model snapshot.', default='', type=str)

    args = parser.parse_args()
    return args

def get_ignored_params(model):
    # Generator function that yields ignored params.
    b = [model.conv1, model.bn1, model.fc_finetune]
    for i in range(len(b)):
        for module_name, module in b[i].named_modules():
            if 'bn' in module_name:
                module.eval()
            for name, param in module.named_parameters():
                yield param

def get_non_ignored_params(model):
    # Generator function that yields params that will be optimized.
    b = [model.layer1, model.layer2, model.layer3, model.layer4]
    for i in range(len(b)):
        for module_name, module in b[i].named_modules():
            if 'bn' in module_name:
                module.eval()
            for name, param in module.named_parameters():
                yield param

def get_fc_params(model):
    # Generator function that yields fc layer params.
    # The original list omitted the *_4 (2-class) heads, so their weights were never updated.
    b = [model.fc_yaw, model.fc_pitch, model.fc_roll,
         model.fc_yaw_1, model.fc_pitch_1, model.fc_roll_1,
         model.fc_yaw_2, model.fc_pitch_2, model.fc_roll_2,
         model.fc_yaw_3, model.fc_pitch_3, model.fc_roll_3,
         model.fc_yaw_4, model.fc_pitch_4, model.fc_roll_4]
    for i in range(len(b)):
        for module_name, module in b[i].named_modules():
            for name, param in module.named_parameters():
                yield param

def load_filtered_state_dict(model, snapshot):
    # By user apaszke from discuss.pytorch.org
    model_dict = model.state_dict()
    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
    model_dict.update(snapshot)
    model.load_state_dict(model_dict)

if __name__ == '__main__':
    args = parse_args()

    cudnn.enabled = True
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    gpu = args.gpu_id

    if not os.path.exists('output/snapshots'):
        os.makedirs('output/snapshots')

    # ResNet50 structure
    model = hopenet.Multinet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 198)   # build the model

    if args.snapshot == '':
        load_filtered_state_dict(model, model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth'))
    else:
        saved_state_dict = torch.load(args.snapshot)
        model.load_state_dict(saved_state_dict)

    print('Loading data.')

    transformations = transforms.Compose([transforms.Resize(240),
    transforms.RandomCrop(224), transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    if args.dataset == 'Pose_300W_LP':
        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Pose_300W_LP_multi':
        pose_dataset = datasets.Pose_300W_LP_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Pose_300W_LP_random_ds':
        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Synhead':
        pose_dataset = datasets.Synhead(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW2000':
        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'BIWI':
        pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'BIWI_multi':
        pose_dataset = datasets.BIWI_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW':
        pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW_multi':        # dataset used to load the (multi-label) training data
        pose_dataset = datasets.AFLW_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW_aug':
        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFW':
        pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
    else:
        print('Error: not a valid dataset name')
        sys.exit()

    train_loader = torch.utils.data.DataLoader(dataset=pose_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    model.cuda(gpu)
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    reg_criterion = nn.MSELoss().cuda(gpu)
    # Regression loss coefficient
    alpha = args.alpha

    softmax = nn.Softmax(dim=1).cuda(gpu)
    idx_tensor = [idx for idx in range(198)]
    idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)

    optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0},
                                  {'params': get_non_ignored_params(model), 'lr': args.lr},
                                  {'params': get_fc_params(model), 'lr': args.lr * 5}],
                                   lr = args.lr)

    print('Ready to train network.')
    for epoch in range(num_epochs):
        for i, (images, labels, labels_0, labels_1, labels_2, labels_3, cont_labels, name) in enumerate(train_loader):
            images = Variable(images).cuda(gpu)

            # Binned labels
            label_yaw = Variable(labels[:,0]).cuda(gpu)
            label_pitch = Variable(labels[:,1]).cuda(gpu)
            label_roll = Variable(labels[:,2]).cuda(gpu)

            label_yaw_1 = Variable(labels_0[:,0]).cuda(gpu)
            label_pitch_1 = Variable(labels_0[:,1]).cuda(gpu)
            label_roll_1 = Variable(labels_0[:,2]).cuda(gpu)

            label_yaw_2 = Variable(labels_1[:,0]).cuda(gpu)
            label_pitch_2 = Variable(labels_1[:,1]).cuda(gpu)
            label_roll_2 = Variable(labels_1[:,2]).cuda(gpu)

            label_yaw_3 = Variable(labels_2[:,0]).cuda(gpu)
            label_pitch_3 = Variable(labels_2[:,1]).cuda(gpu)
            label_roll_3 = Variable(labels_2[:,2]).cuda(gpu)

            label_yaw_4 = Variable(labels_3[:,0]).cuda(gpu)
            label_pitch_4 = Variable(labels_3[:,1]).cuda(gpu)
            label_roll_4 = Variable(labels_3[:,2]).cuda(gpu)

            # Continuous labels
            label_yaw_cont = Variable(cont_labels[:,0]).cuda(gpu)
            label_pitch_cont = Variable(cont_labels[:,1]).cuda(gpu)
            label_roll_cont = Variable(cont_labels[:,2]).cuda(gpu)

            # Forward pass: logits of all heads
            yaw,yaw_1,yaw_2,yaw_3,yaw_4, pitch,pitch_1,pitch_2,pitch_3,pitch_4, roll,roll_1,roll_2,roll_3,roll_4 = model(images)

            # Cross entropy loss of every head
            loss_yaw,loss_yaw_1,loss_yaw_2,loss_yaw_3,loss_yaw_4 = criterion(yaw, label_yaw),criterion(yaw_1, label_yaw_1),criterion(yaw_2, label_yaw_2),criterion(yaw_3, label_yaw_3),criterion(yaw_4, label_yaw_4)
            loss_pitch,loss_pitch_1,loss_pitch_2,loss_pitch_3,loss_pitch_4 = criterion(pitch, label_pitch),criterion(pitch_1, label_pitch_1),criterion(pitch_2, label_pitch_2),criterion(pitch_3, label_pitch_3),criterion(pitch_4, label_pitch_4)
            loss_roll,loss_roll_1,loss_roll_2,loss_roll_3,loss_roll_4 = criterion(roll, label_roll),criterion(roll_1, label_roll_1),criterion(roll_2, label_roll_2),criterion(roll_3, label_roll_3),criterion(roll_4, label_roll_4)

            # MSE loss: normalize the 198-bin logits with softmax
            yaw_predicted = softmax(yaw)
            pitch_predicted = softmax(pitch)
            roll_predicted = softmax(roll)

            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) - 99  # expectation in degrees, computed as in Deep Head Pose
            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) - 99
            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) - 99

            loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont)  # same as in Deep Head Pose
            loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont)
            loss_reg_roll = reg_criterion(roll_predicted, label_roll_cont)

            # Total loss: weighted sum of the losses over all bin widths
            total_loss_yaw = alpha * loss_reg_yaw + 7*loss_yaw + 5*loss_yaw_1 + 3*loss_yaw_2 + 1*loss_yaw_3 + 1*loss_yaw_4
            total_loss_pitch = alpha * loss_reg_pitch + 7*loss_pitch + 5*loss_pitch_1 + 3*loss_pitch_2 + 1*loss_pitch_3 + 1*loss_pitch_4
            total_loss_roll = alpha * loss_reg_roll + 7*loss_roll + 5*loss_roll_1 + 3*loss_roll_2 + 1*loss_roll_3 + 1*loss_roll_4  # the original used loss_pitch_4 here

            loss_seq = [total_loss_yaw, total_loss_pitch, total_loss_roll]
            grad_seq = [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]
            optimizer.zero_grad()
            torch.autograd.backward(loss_seq, grad_seq)
            optimizer.step()

            if (i+1) % 100 == 0:
                print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
                       %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, total_loss_yaw.item(), total_loss_pitch.item(), total_loss_roll.item()))
        # Save models at numbered epochs.
        if epoch % 1 == 0 and epoch < num_epochs:
            print('Taking snapshot...')
            torch.save(model.state_dict(),
            'output/snapshots/' + args.output_string + '_epoch_'+ str(epoch+1) + '.pkl')
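The training loop calls `torch.autograd.backward(loss_seq, grad_seq)` with every entry of `grad_seq` equal to 1. Back-propagation is linear, so this accumulates the same parameter gradients as calling `(total_loss_yaw + total_loss_pitch + total_loss_roll).backward()`. Each per-angle loss $\mathcal{L}_a$ uses the weights from the code. Here $H_w$ is my shorthand for the cross-entropy of the head with bin width $w$ degrees, and $a$ stands for yaw, pitch or roll:

```latex
\nabla_\theta\left(\mathcal{L}_{yaw}+\mathcal{L}_{pitch}+\mathcal{L}_{roll}\right)
  = \nabla_\theta\mathcal{L}_{yaw}+\nabla_\theta\mathcal{L}_{pitch}+\nabla_\theta\mathcal{L}_{roll},
\qquad
\mathcal{L}_{a} = \alpha\,\mathrm{MSE}(a, a^{*}) + 7H_{1} + 5H_{3} + 3H_{11} + H_{33} + H_{99}
```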

2.3 Test code

import argparse
import os
import sys

import cv2
import torch
import torch.backends.cudnn as cudnn
import torchvision
from torch.autograd import Variable
from torchvision import transforms

import datasets, hopenet, utils   # project modules from the official repository


def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=0, type=int)
    parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.', default='/AFLW2000/', type=str)
    parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
          default='/tools/AFLW2000_filename_filtered.txt', type=str)
    parser.add_argument('--snapshot', dest='snapshot', help='Name of model snapshot.',
          default='/output/snapshots/AFLW2000/_epoch_9.pkl', type=str)
    parser.add_argument('--batch_size', dest='batch_size', help='Batch size.', default=1, type=int)
    # type=bool would turn any non-empty string (even 'False') into True, so use a flag instead
    parser.add_argument('--save_viz', dest='save_viz', help='Save images with pose cube.', action='store_true')
    parser.add_argument('--dataset', dest='dataset', help='Dataset type.', default='AFLW2000', type=str)

    args = parser.parse_args()

    return args

if __name__ == '__main__':
    args = parse_args()

    cudnn.enabled = True
    gpu = args.gpu_id
    snapshot_path = args.snapshot

    # ResNet50 structure
    model = hopenet.Multinet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 198)

    print('Loading snapshot.')
    # Load snapshot
    saved_state_dict = torch.load(snapshot_path)
    model.load_state_dict(saved_state_dict)

    print('Loading data.')

    transformations = transforms.Compose([transforms.Resize(224),
    transforms.CenterCrop(224), transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    if args.dataset == 'Pose_300W_LP':
        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Pose_300W_LP_multi':
        pose_dataset = datasets.Pose_300W_LP_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Pose_300W_LP_random_ds':
        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Synhead':
        pose_dataset = datasets.Synhead(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW2000':
        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'BIWI':
        pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'BIWI_multi':
        pose_dataset = datasets.BIWI_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW':
        pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW_multi':
        pose_dataset = datasets.AFLW_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW_aug':
        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFW':
        pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
    else:
        print('Error: not a valid dataset name')
        sys.exit()
    test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=2)

    model.cuda(gpu)

    print('Ready to test network.')

    # Test the Model
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
    total = 0

    idx_tensor = [idx for idx in range(198)]
    idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)

    yaw_error = .0
    pitch_error = .0
    roll_error = .0

    l1loss = torch.nn.L1Loss(reduction='sum')  # size_average=False is deprecated (unused below)
    if args.save_viz:
        os.makedirs('output/images', exist_ok=True)  # cv2.imwrite does not create missing folders
    for i, (images, labels, cont_labels, name) in enumerate(test_loader):
    #for i, (images, labels, labels_0, labels_1, labels_2, labels_3, cont_labels, name) in enumerate(test_loader):
        images = Variable(images).cuda(gpu)
        total += cont_labels.size(0)

        label_yaw = cont_labels[:,0].float()
        label_pitch = cont_labels[:,1].float()
        label_roll = cont_labels[:,2].float()

        yaw,yaw_1,yaw_2,yaw_3,yaw_4, pitch,pitch_1,pitch_2,pitch_3,pitch_4, roll,roll_1,roll_2,roll_3,roll_4 = model(images)  # logits of all heads

        # Binned predictions
        _, yaw_bpred = torch.max(yaw.data, 1)
        _, pitch_bpred = torch.max(pitch.data, 1)
        _, roll_bpred = torch.max(roll.data, 1)

        # Continuous predictions
        yaw_predicted = utils.softmax_temperature(yaw.data, 1)  # softmax with temperature (T = 1)
        pitch_predicted = utils.softmax_temperature(pitch.data, 1)
        roll_predicted = utils.softmax_temperature(roll.data, 1)

        yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() - 99     # expectation over the 198 bins, in degrees
        pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() - 99
        roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu() - 99

        # Mean absolute error
        yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw))
        pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
        roll_error += torch.sum(torch.abs(roll_predicted - label_roll))

        # Save first image in batch with pose cube or axis.
        if args.save_viz:
            name = name[0]
            if args.dataset == 'BIWI':
                cv2_img = cv2.imread(os.path.join(args.data_dir, name + '_rgb.png'))
            else:
                cv2_img = cv2.imread(os.path.join(args.data_dir, name + '.jpg'))
            if args.batch_size == 1:
                error_string = 'y %.2f, p %.2f, r %.2f' % (torch.sum(torch.abs(yaw_predicted - label_yaw)), torch.sum(torch.abs(pitch_predicted - label_pitch)), torch.sum(torch.abs(roll_predicted - label_roll)))
                cv2.putText(cv2_img, error_string, (30, cv2_img.shape[0]- 30), fontFace=1, fontScale=1, color=(0,0,255), thickness=2)
            # utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], size=100)
            utils.draw_axis(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], tdx = 200, tdy= 200, size=100)
            cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)

    print('Test error in degrees of the model on the ' + str(total) +
    ' test images. Yaw: %.4f, Pitch: %.4f, Roll: %.4f, MAE: %.4f' % (yaw_error / total,
    pitch_error / total, roll_error / total, (yaw_error+pitch_error+roll_error)/(3.0*total)))
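The final line of the test script prints the mean absolute error (MAE) of each angle over all test images, plus their average. Dividing the summed error by 3 × total is the same as averaging the three per-angle MAEs. Below is a plain-Python sketch of that arithmetic; the function name and the tuple layout are hypothetical.

```python
def mean_absolute_errors(preds, labels):
    """preds, labels: lists of (yaw, pitch, roll) tuples in degrees.

    Returns (yaw MAE, pitch MAE, roll MAE, overall MAE), mirroring the test script's final print.
    """
    n = len(labels)
    sums = [0.0, 0.0, 0.0]
    for p, t in zip(preds, labels):
        for k in range(3):
            sums[k] += abs(p[k] - t[k])   # accumulate |prediction - label| per angle
    yaw, pitch, roll = (s / n for s in sums)
    return yaw, pitch, roll, (yaw + pitch + roll) / 3.0


if __name__ == "__main__":
    preds = [(1.0, 2.0, 3.0), (0.0, 0.0, 0.0)]
    labels = [(0.0, 0.0, 0.0), (0.0, 0.0, 0.0)]
    print(mean_absolute_errors(preds, labels))   # (0.5, 1.0, 1.5, 1.0)
```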

2.4 softmax_temperature code

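import torch

def softmax_temperature(tensor, temperature):
    # Softmax with temperature along dim 1: exp(z_i / T) / sum_j exp(z_j / T)
    scaled = tensor / temperature
    scaled = scaled - scaled.max(dim=1, keepdim=True)[0]  # subtract the row max for numerical stability (result unchanged)
    result = torch.exp(scaled)
    return result / result.sum(dim=1, keepdim=True)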

 

posted on 2020-01-04 19:36  darkknightzh
