【推荐算法】Wide & Deep

Wide & Deep主要解决了FM的以下几个痛点:

  1. 强化记忆能力。FM这类embedding类模型拥有强大的泛化能力,在embedding的过程中传入了大量的全局信息,对于一些很少出现甚至没有出现过的特征组合,也能计算出合理的特征组合权重。但是,当共现矩阵过于稀疏时,模型会过分泛化,推荐出很多相关性低的内容。Wide & Deep不仅强调了模型的“泛化能力”,还强调了“记忆能力”。具体来说,记忆能力表示一类规则式的推荐,即当出现情况A时,就推荐B。

算法

如图,模型由左边的wide与右边的deep结合而成,两者的结果相加后经过一个sigmoid层得到最终的CTR预估值。其中,deep中包含所有特征,wide中包含人工设计的需要加强记忆能力的特征交叉(如下载软件A与展现软件B),特征交叉方法使用交叉积,可以看做特征one-hot表示下的向量内积。

模型输入

首先定义一个mlp实现deep部分:

class MultiLayerPerceptron(torch.nn.Module):
    def __init__(self, input_dim, embed_dims, dropout):
        super().__init__()
        layers = list()
        for embed_dim in embed_dims:
            layers.append(torch.nn.Linear(input_dim, embed_dim))
            layers.append(torch.nn.BatchNorm1d(embed_dim))
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Dropout(p=dropout))
            input_dim = embed_dim
        layers.append(torch.nn.Linear(input_dim, 1))
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, embed_dim)``
        """
        return self.mlp(x)

为了简化,我们使用全量物品特征加入wide部分,并且不手工设计特征交叉。代码中使用LR的实现来实现wide部分(很不严谨)。最后,构建完整的DeepFM前向传播链路:

class WideAndDeepModel(torch.nn.Module):
    def __init__(self, field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2):
        super().__init__()
        self.linear = FeaturesLinear(field_dims)
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.embed_output_dim = len(field_dims) * embed_dim
        self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        embed_x = self.embedding(x)
        x = self.linear(x) + self.mlp(embed_x.view(-1, self.embed_output_dim))
        return torch.sigmoid(x.squeeze(1))

模型效果

设置:
数据集:ml-100k
优化方法:Adam
学习率:0.003

效果:
收敛epoch: 10
train logloss: 0.52100
val auc: 0.78681
test auc: 0.78500

posted @ 2021-06-30 16:32  tmpUser  阅读(161)  评论(0编辑  收藏  举报