深度排序模型概述(二)PNN/NFM/AFM

在CTR预估中,为了解决稀疏特征的问题,学者们提出了FM模型来建模特征之间的交互关系。但是FM模型只能表达特征之间两两组合之间的关系,无法建模两个特征之间深层次的关系或者说多个特征之间的交互关系,因此学者们通过Deep Network来建模更高阶的特征之间的关系。
因此,FM和深度网络DNN的结合也就成为了CTR预估问题中主流的方法。有关FM和DNN的结合有两种主流的方法,并行结构和串行结构。两种结构的理解以及实现如下表所示:

结构 描述 常见模型
并行结构 M部分和DNN部分分开计算,只在输出层进行一次融合得到结果 DeepFM,DCN,Wide&Deep
串行结构 将FM的一次项和二次项结果(或其中之一)作为DNN部分的输入,经DNN得到最终结果 PNN,NFM,AFM

PNN

PNN,全称为Product-based Neural Network,认为在embedding输入到MLP之后学习的交叉特征表达并不充分,提出了一种product layer的思想,即基于乘法的运算来体现体征交叉的DNN网络结构,如下图:

PNN网络结构

PNN包括三层:Embedding Layer、Product Layer、Full-connect Layer。
按照论文的思路,我们也从上往下来看这个网络结构:

  • 输出层
    输出层很简单,将上一层的网络输出通过一个全链接层,经过sigmoid函数转换后映射到(0,1)的区间中,得到我们的点击率的预测值:

y^=σ(W3l2+b3)

  • l2层
    根据l1层的输出,经一个全链接层 ,并使用relu进行激活,得到我们l2的输出结果:

l2=relu(W2l1+b2)

  • l1层
    l1层的输出由如下的公式计算:

l1=relu(lz+lp+b1)

其中,b1是偏置项,l_z,l_p由Product Layer计算得来

  • Product Layer
    product思想来源于,在ctr预估中,认为特征之间的关系更多是一种and“且”的关系,而非add"加”的关系。例如,性别为男且喜欢游戏的人群,比起性别男和喜欢游戏的人群,前者的组合比后者更能体现特征交叉的意义。
    product layer可以分成两个部分,一部分是线性部分lz,一部分是非线性部分lp。

lzn=Wznzlpi=Wpnp

其中,z是线性信号向量,可以认为z就是embedding层的复制。对于p来说,这里需要一个公式进行映射:

                               p={pi,j},   i=1...N,j=1...Npi,j=g(fi,fj)

不同的g的选择使得我们有了两种PNN的计算方法,一种叫做Inner PNN,简称IPNN,一种叫做Outer PNN,简称OPNN。

  • Embedding Layer
    Embedding Layer跟DeepFM中相同,将每一个field的特征转换成同样长度的向量

  • 损失函数
    使用和逻辑回归同样的损失函数,如下:

L(y,y^)=ylogy^(1y)log(1y^)

接下来,我们分别来具体介绍IPNN和OPNN,由于涉及到复杂度的分析,所以我们这里先定义Embedding的大小为M,field的大小为N,而lz和lp的长度为D1。

IPNN

IPNN的示意图如下:

IPNN网络结构

IPNN中p的计算方式如下,即使用内积来代表pij:

g(fi,fj)=<fi,fj>

所以,pij其实是一个数,得到一个pij的时间复杂度为M,p的大小为NN,因此计算得到p的时间复杂度为NNM。而再由p得到lp的时间复杂度是NND1。因此 对于IPNN来说,总的时间复杂度为NN(D1+M)。文章对这一结构进行了优化,可以看到,我们的p是一个对称矩阵,因此我们的权重也可以是一个对称矩阵,对称矩阵就可以进行如下的分解:

Wpn=θnθnT

最终,我们的权重只需要D1N就可以了,时间复杂度也变为了D1MN

OPNN

OPNN的示意图如下:
OPNN网络结构

OPNN中p的计算方式如下:

pi,j=g(fi,fj)=fifjT

此时pijMM的矩阵,计算一个pij的时间复杂度为MM,而p是NNMM的矩阵,因此计算p的事件复杂度为NNMM。从而计算lp的时间复杂度变为D1NNMM。这个显然代价很高的。为了减少负责度,论文使用了叠加的思想,它重新定义了p矩阵:

p=i=1Nj=1NfifjT=f(f)T,  f=i=1Nfi

这里计算p的时间复杂度变为了D1M(M+N)

NFM

NFM模型(Neural Factorization Machine),是串行结构中一种较为简单的网络模型。
首先来回顾一下FM模型,FM模型用n个隐变量来刻画特征之间的交互关系。这里要强调的一点是,n是特征的总数,是one-hot展开之后的,比如有三组特征,两个连续特征,一个离散特征有5个取值,那么n=7而不是n=3。

FM模型表达式:

y^FM(x)=w0+i=1nwixi+i=1nj=i+1nvi,vjxixj

顺便回顾一下化简过程:

i=1nj=i+1nvi,vjxixj(1)=12i=1nj=1nvi,vjxixj12i=1nvi,vixixi(2)=12(i=1nj=1nf=1kvi,fvj,fxixji=1nf=1kvi,fvi,fxixi)(3)=12f=1k(i=1nvi,fxi)(j=1nvj,fxj)i=1nvi,f2xi2(4)=12f=1k(i=1nvi,fxi)2i=1nvi,f2xi2(5)

可以看到,不考虑最外层的求和,我们可以得到一个K维的向量。

对于NFM模型,目标值的预测公式变为:

y^NFM(x)=w0+i=1nwixi+f(x)

其中,f(x)是用来建模特征之间交互关系的多层前馈神经网络模块,架构图如下所示:

NFM网络结构

Embedding Layer和我们之前几个网络是一样的,embedding 得到的vector其实就是我们在FM中要学习的隐变量v。

Bi-Interaction Layer,其实就是计算FM中的二次项的过程,因此得到的向量维度就是我们的Embedding的维度。最终的结果是:

fBI(vx)=12[(i=1nxivi)2i=1n(xivi)2]

Hidden Layers就是我们的DNN部分,将Bi-Interaction Layer得到的结果接入多层的神经网络进行训练,从而捕捉到特征之间复杂的非线性关系。

在进行多层训练之后,将最后一层的输出求和同时加上一次项和偏置项,就得到了我们的预测输出:

y^NFM(x)=w0+i=1nwixi+hTσL(WL(...σ1(W1fBI(Vx)+b1)...)+bL)

NFM模型将FM与神经网络结合以提升FM捕捉特征间多阶交互信息的能力。根据论文中实验结果,NFM的预测准确度相较FM有明显提升,并且与现有的并行神经网络模型相比,复杂度更低。

NFM本质上还是基于FM,FM会让一个特征固定一个特定的向量,当这个特征与其他特征做交叉时,都是用同样的向量去做计算。这个是很不合理的,因为不同的特征之间的交叉,重要程度是不一样的。因此,学者们提出了AFM模型(Attentional factorization machines),将attention机制加入到我们的模型中。

AFM

AFM只是在FM的基础上添加了attention的机制。关于什么是attention model?本文不打算详细赘述,我们这里只需要知道的是,attention机制相当于一个加权平均,attention的值就是其中权重,判断不同特征之间交互的重要性。
attention相等于加权的过程,因此我们的预测公式变为:

y^AFM(x)=w0+i=1nwixi+pTi=1nj=i+1naij(vivj)xixj

圆圈中有个点的符号代表的含义是element-wise product,即:

(x1,1,x1,2...)(x2,1,x2,2...)=(x1,1x2,1,x1,2x2,2,...)

因此,我们在求和之后得到的是一个K维的向量,还需要跟一个向量p相乘,得到一个具体的数值。

可以看到,AFM的前两部分和FM相同,后面的一项经由如下的网络得到:
AFM网络结构

图中的前三部分:sparse iput,embedding layer,pair-wise interaction layer,都和FM是一样的。而后面的两部分,则是AFM的创新所在,也就是我们的Attention net。Attention背后的数学公式如下:

aij=hTRelu(W(vivj)xixj+b)aij=exp(aij)i,jexp(aij)

总结一下,不难看出AFM只是在FM的基础上添加了attention的机制,但是实际上,由于最后的加权累加,二次项并没有进行更深的网络去学习非线性交叉特征,所以AFM并没有发挥出DNN的优势,也许结合DNN可以达到更好的结果。

posted @   Jamest  阅读(3272)  评论(0编辑  收藏  举报
编辑推荐:
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)
点击右上角即可分享
微信分享提示