[论文理解] SSR-Net: A Compact Soft Stagewise Regression Network for Age Estimation
SSR-Net: A Compact Soft Stagewise Regression Network for Age Estimation
简介
本文提出了一种年龄回归的方法,定义了由粗到细进行classification的过程,并且取得了不错的实验效果。文章指出,regression-based年龄估计的方法很容易出现过拟合现象,因此很难得到广泛应用,这是因为年龄的随机性和从人脸特征到年龄特征的映射很庞大的原因,因此很难建立从人脸特征到连续空间的庞大映射而不过拟合,所以采用从人脸特征到离散空间的映射是一个好方法。但是,单纯把年龄识别任务当成一个多分类分组任务又会导致学习到的分类结果太硬,也就是说,不能体现个分组之间的相关性,而是把各个分组当成不相关的类别进行分类的,这其实不太正确。此外,其他方法往往需要大量额外信息,本文提出的方法只需要脸部信息,而不需要其他额外信息就可以完成建模。
Stagewise Regression
本节主要讲述作者由DEX方法得到灵感,如何定义年龄识别这个回归任务。
首先,把要分类的年龄Y分为s个bin
那么每个bin的宽度就是
并且定义
为第i个bin的representive age,具体就是一个代表值,比如[30,34]区间我们可以取中间值代替,只是举个例子。
那么如何由每个bin的representive age得到predicted age呢?
这需要我们网络学习到一个年龄分布p
使得
其中y就是预测的年龄值,网络学习的就是向量p。
但是由于上面的模型呢参数量还是有点多,因为DEX中每个年龄一个类别,这样会导致全连接的参数非常多,所以本文提出了由粗到细的结构,分为多个stage,每个stage只做简单分类,比如第一个stage只做一个三分类,能够划分人是儿童、青年、老年人。然后再再下一个stage继续细分。
如果网络有k个stage,那么最终的输出就可以定义为
Dynamic Range
作者认为年龄group比较大的时候,年龄bin之间没有重叠,会显得不太灵活,就是体现不了各个bin之间相关性,所以这里设计了dynamic range。一个shift一个scale。
首先是scale,这是为了能动态改变bin的width。
先验认为第k个stage分为sk个bin,那么scale之后变为
delta k 是可以学习的参数。
然后是shift。
shift部分作者通过改变bin的索引实现。认为第k个stage有sk个代表向量,所以就需要sk个偏移量来偏移他们的索引。
而shift和scale操作都依赖于特定输入,也就是说,对于给定的输入,shift和scale是定的,这就能让网络帮助我们学习到如何去划分stage,划分多少个stage,这一点也是很重要的。
网络结构
网络结构大致如图:
如上所说,网络分为两个stream,两个stream使用的激活函数不同,为了获取不同的特征。每个stage都是为了得到三个训练参数,得到delta的过程就是先1*1卷积,然后relu,pooling,两个stream pooling后的结果对应相乘然后fc到1再用tanh限定到-1到1得到delta。
p和eta得到的过程如图类似。比较简单。
论文复现:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SSRNet(nn.Module):
def __init__(self):
super(SSRNet,self).__init__()
self.stream1_k1 = nn.Sequential(
nn.Conv2d(3,32,kernel_size = 3),
nn.BatchNorm2d(32),
nn.ReLU(inplace = True),
nn.AvgPool2d(2)
)
self.stream2_k1 = nn.Sequential(
nn.Conv2d(3,16,kernel_size = 3),
nn.BatchNorm2d(16),
nn.Tanh(),
nn.MaxPool2d(2)
)
self.stream1_k2 = nn.Sequential(
nn.Conv2d(32,32,kernel_size = 3),
nn.BatchNorm2d(32),
nn.ReLU(inplace = True),
nn.AvgPool2d(2)
)
self.stream2_k2 = nn.Sequential(
nn.Conv2d(16,16,kernel_size = 3),
nn.BatchNorm2d(16),
nn.Tanh(),
nn.MaxPool2d(2)
)
self.stream1_k3 = nn.Sequential(
nn.Conv2d(32,32,kernel_size = 3),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.AvgPool2d(2),
nn.Conv2d(32,32,kernel_size = 3),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.AvgPool2d(2)
)
self.stream2_k3 = nn.Sequential(
nn.Conv2d(16,16,kernel_size = 3),
nn.BatchNorm2d(16),
nn.Tanh(),
nn.MaxPool2d(2),
nn.Conv2d(16,16,kernel_size = 3),
nn.BatchNorm2d(16),
nn.Tanh(),
nn.MaxPool2d(2),
)
self.s1_pre1 = nn.Sequential(
nn.Conv2d(32,32,kernel_size = 1),
nn.ReLU(inplace= True),
nn.MaxPool2d(31)
)
self.s2_pre1 = nn.Sequential(
nn.Conv2d(16,32,kernel_size = 1),
nn.ReLU(inplace= True),
nn.AvgPool2d(31)
)
self.s1_pre2 = nn.Sequential(
nn.Conv2d(32,32,kernel_size = 1),
nn.ReLU(inplace= True),
nn.MaxPool2d(14)
)
self.s2_pre2 = nn.Sequential(
nn.Conv2d(16,32,kernel_size = 1),
nn.ReLU(inplace= True),
nn.AvgPool2d(14)
)
self.s1_pre3 = nn.Sequential(
nn.Conv2d(32,32,kernel_size = 1),
nn.ReLU(inplace= True),
nn.MaxPool2d(2)
)
self.s2_pre3 = nn.Sequential(
nn.Conv2d(16,32,kernel_size = 1),
nn.ReLU(inplace= True),
nn.AvgPool2d(2)
)
self.delta_k1 = nn.Sequential(
nn.Linear(32,1),
nn.Tanh()
)
self.delta_k2 = nn.Sequential(
nn.Linear(32,1),
nn.Tanh()
)
self.delta_k3 = nn.Sequential(
nn.Linear(32,1),
nn.Tanh()
)
self.vec_p1 = nn.Sequential(
nn.Linear(16,3),
nn.ReLU()
)
self.vec_p2 = nn.Sequential(
nn.Linear(16,3),
nn.ReLU()
)
self.vec_p3 = nn.Sequential(
nn.Linear(16,3),
nn.ReLU()
)
self.eta1 = nn.Sequential(
nn.Linear(16,3),
nn.Tanh()
)
self.eta2 = nn.Sequential(
nn.Linear(16,3),
nn.Tanh()
)
self.eta3 = nn.Sequential(
nn.Linear(16,3),
nn.Tanh()
)
self.s1_dr1 = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(32,16),
nn.ReLU()
)
self.s2_dr1 = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(32,16),
nn.ReLU()
)
self.s1_dr2 = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(32,16),
nn.ReLU()
)
self.s2_dr2 = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(32,16),
nn.ReLU()
)
self.s1_dr3 = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(32,16),
nn.ReLU()
)
self.s2_dr3 = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(32,16),
nn.ReLU()
)
self.idx1 = torch.tensor(range(3),dtype=torch.float)
self.idx2 = torch.tensor(range(3),dtype=torch.float)
self.idx3 = torch.tensor(range(3),dtype=torch.float)
def forward(self,x):
t1 = self.stream1_k1(x)
#torch.Size([2, 32, 31, 31])
t2 = self.stream1_k2(t1)
t3 = self.stream1_k3(t2)
t4 = self.stream2_k1(x)
t5 = self.stream2_k2(t4)
t6 = self.stream2_k3(t5)
#print(self.s1_pre1(t1).shape)
#print(self.s2_pre1(t4).shape)
#torch.Size([2, 32, 15, 15])
#torch.Size([2, 32, 15, 15])
#print((self.s1_pre1(t1) * self.s2_pre1(t4)).shape)
#torch.Size([2, 32, 15, 15])
delta_k_1 = self.delta_k1((self.s1_pre1(t1) * self.s2_pre1(t4)).view(-1,32))
delta_k_2 = self.delta_k2((self.s1_pre2(t2) * self.s2_pre2(t5)).view(-1,32))
delta_k_3 = self.delta_k3((self.s1_pre3(t3) * self.s2_pre3(t6)).view(-1,32))
#print(self.s1_pre1(t1).shape)
#print(self.s1_dr1(self.s1_pre1(t1).view(-1,32)).shape)
vec_p_1 = self.vec_p1(self.s1_dr1(self.s1_pre1(t1).view(-1,32)) * self.s2_dr1(self.s2_pre1(t4).view(-1,32)))
#print(vec_p_1.shape)
vec_p_2 = self.vec_p2(self.s1_dr2(self.s1_pre2(t2).view(-1,32)) * self.s2_dr2(self.s2_pre2(t5).view(-1,32)))
vec_p_3 = self.vec_p3(self.s1_dr3(self.s1_pre3(t3).view(-1,32)) * self.s2_dr3(self.s2_pre3(t6).view(-1,32)))
eta1 = self.eta1(self.s1_dr1(self.s1_pre1(t1).view(-1,32)) * self.s2_dr1(self.s2_pre1(t4).view(-1,32)))
eta2 = self.eta2(self.s1_dr2(self.s1_pre2(t2).view(-1,32)) * self.s2_dr2(self.s2_pre2(t5).view(-1,32)))
eta3 = self.eta3(self.s1_dr3(self.s1_pre3(t3).view(-1,32)) * self.s2_dr3(self.s2_pre3(t6).view(-1,32)))
# 可以理解为 30 + 5 + 0.5 = 35.5 这么去从粗到细预测
output1 = torch.sum((self.idx1.view(1,3) + eta1) * vec_p_1 / 3 / (1 + delta_k_1),dim = 1)
output2 = torch.sum((self.idx2.view(1,3) + eta2) * vec_p_2 / 3 / (1 + delta_k_1) / 3 / (1 + delta_k_2),dim=1)
output3 = torch.sum((self.idx3.view(1,3) + eta3) * vec_p_3 / 3 / (1 + delta_k_1) / 3 / (1 + delta_k_2) / 3 /(1 + delta_k_3),dim=1)
print(output1,output2,output2)
return (output1 + output2 + output3) * 101
if __name__ == "__main__":
from torchsummary import summary
model = SSRNet()
#summary(model,(3,64,64),device = "cpu")
x = torch.randn(2,3,64,64)
print(model(x))