推荐系统 和 Torch-RecHub 概述

推荐系统是缓解互联网等环境下信息过载问题的重要方法之一,在电商、电影和音乐等很多消费场景发挥了重要价值。然而,随着业务场景的不断更新和提出新的要求,推荐系统的设计也必须要不断改进。作为研究者或者工程师,有必要掌握一套好的工具来实现和改进推荐系统相关的架构和模型,提升科研和工作效率。

Torch-RecHub 是一个轻量级的pytorch推荐模型框架,易用易拓展,聚焦复现业界实用的推荐模型,以及泛生态化的推荐场景,是值得学习和掌握的。而且,Torch-RecHub 有优秀的团队一直在维护和改进,并且有良好的社群可以学习和讨论。

本文将简要介绍推荐系统相关的模块和技术,以及 Torch_RecHub 的快速使用方式。后续的文章会详细介绍一些具体的前沿模型原理和使用方式。

1. 推荐系统

1.1 价值和应用场景(应用层)

随着移动互联网的飞速发展,人们已经处于一个信息过载(VS 信息不足)的时代。在这个时代中,信息的生产者很难将信息呈现在对它们感兴趣的信息消费者面前,而对于信息消费者也很难从海量的信息中找到自己感兴趣的信息(搜寻成本过高)。

目前处理信息过载问题的主要方式是搜索(主动查询,例如百度搜索)和推荐。不同于搜索,推荐系统可以根据用户的历史行为记录推测出潜在的兴趣并推荐对应的商品(类似于商场销售员的推荐,要比自己挑选更能提高效率和成交率)。具体来说,对于搜索、推荐、广告这三个领域的区别和联系可以参考 王喆老师写的排得更好VS估得更准VS搜的更全「推荐、广告、搜索」算法间到底有什么区别?

推荐系统实施后,平台能够最大限度地吸引用户、留存用户、增加用户粘性、提高用户转化率,从而达到平台商业目标增长的目的。对于信息消费者(用户),推荐系统可以降低自己的搜寻成本、提高自己的购物体验、发现一些意想不到的好商品。对于信息生产者(物品或商家),推荐系统给予了那些长尾或冷门商品曝光的机会,而不是被那些热门信息或商品所淹没,有助于激励那些中小商家或信息生产者提供多样性的物品或内容。

推荐的场景也非常多,就像广告一样,几乎无孔不入,比如 APP打开后的首页推荐(淘宝、京东、拼多多)、消费指定商品时的关联推荐(抖音、快手、B站、爱奇艺等的视频推荐;网易云音乐等的音乐推荐;知乎、豆瓣等的资讯推荐)。根据推荐的对象、推荐的环节和推荐目标的不同,所适用的推荐架构及其技术都可能是不一样的。因此,为了成为优秀的推荐系统工程师或者科学家,必须同时掌握业务场景和工程技术两方面的知识。

1.2 架构和模块组合(架构层)

推荐系统的核心任务是从海量物品中找到用户感兴趣的内容。在这个背景下,推荐系统包含的模块非常多,每个模块将会有很多专业研究的工程和研究工程师。俗话说,不谋全局者,不足以谋一域。因此,我们有必要梳理清楚整个推荐系统的架构,知道每一个部分需要完成哪些任务,是如何做的,主要的技术栈是什么,有哪些局限和可以研究的问题。本文将会从系统架构算法架构两个角度出发解析推荐系统通用架构。

1.2.1 系统架构

系统架构的设计思想是大数据背景下如何有效利用海量和实时数据,将推荐系统按照对数据利用情况和系统响应要求出发,将整个架构分为离线层近线层在线层三个模块,更多的是考虑推荐算法在工程技术实现上的问题。下图是 Netflix在2013年提出的架构,其中的技术也许不是现在常用的技术了,但是架构模型仍然被很多公司采用。

简单来说,各层的功能如下所示,本文之后会详细介绍:

  • 在线层:使用实时数据,保证实时在线服务;将用户在平台上真实的行为记录下来,比如用户看到了哪些内容,和哪些内容发生了交互,和哪些没有发生了交互。如果再精细一点,还会记录用户停留的时间,用户使用的设备等等。主要是后端和客户端完成。
  • 近线层:使用实时数据,不保证实时响应;记录一些准实时的数据,比如用户之前还看过哪些内容,和哪些内容发生过交互。理想情况这部分数据也需要做成实时,但由于数据量比较大、逻辑相对复杂,一般都是通过消息队列加在线缓存的方式做成准实时。
  • 离线层:不用实时数据,不提供实时响应;数据处理的大头,但对时间和算法复杂等基本没有要求。所有“脏活累活”复杂的操作都是在离线完成的,比如训练一些会长期使用的大模型。

离线层

离线层是计算量最大的一个部分,需要实现的主要功能模块是:数据处理、数据存储;特征工程、离线特征计算;离线模型的训练。目前主流的做法是HDFS,收集到我们所有的业务数据,通过HIVE等工具,从全量数据中抽取出我们需要的数据,进行相应的加工,离线阶段主流使用的分布式框架一般是Spark,如下图所示:

离线任务一般会按照天或者更久运行,比如每天晚上定期更新这一天的数据,然后重新训练模型,第二天上线新模型。

离线层面临的数据量级是最大的,面临主要的问题是海量数据存储、大规模特征工程、多机分布式机器学习模型训练。

离线层的优势:可以处理大量的数据,进行大规模特征工程;可以进行批量处理和计算;不用有响应时间要求。
离线层的劣势:无法反应用户的实时兴趣变化。

近线层

近线层的产生是同时想要弥补离线层和在线层的不足,折中的产物。它可以获得实时数据,然后快速计算提供服务,但是并不要求它和在线层一样达到几十毫秒这种延时要求。

它适合处理一些对延时比较敏感的任务,比如:

  • 特征的实时更新计算:推荐系统一个老生常谈的问题就是特征分布不一致怎么办,如果使用离线算好的特征就容易出现这个问题。近线层能够获取实时数据,按照用户的实时兴趣计算就能很好避免这个问题。
  • 实时训练数据的获取:比如在使用DIN、DSIN这行网络会依赖于用户的实时兴趣变化,用户几分钟前的点击就可以通过近线层获取特征输入模型。
  • 模型实时训练:可以通过在线学习的方法更新模型,实时推送到线上;

近线层的发展得益于最近几年大数据技术的发展,很多流处理框架的提出大大促进了近线层的进步。如今Flink、Storm等工具一统天下。

在线层

所有的用户请求都会发送到在线层,在线层需要快速返回结果,它主要承担的工作有:

  • 模型在线服务:包括了快速召回和排序;
  • 在线特征快速处理拼接:根据传入的用户ID和场景,快速读取特征和处理;
  • AB实验或者分流:根据不同用户采用不一样的模型,比如冷启动用户和正常服务模型;
  • 运筹优化和业务干预:比如要对特殊商家流量扶持、对某些内容限流;

在线服务的数据源就是我们在离线层计算好的每个用户和商品特征,我们事先存放在数据库中,在线层只需要实时拼接,不进行复杂的特征运算,然后输入近线层或者离线层已经训练好的模型,根据推理结果进行排序,最后返回给后台服务器,后台服务器根据我们对每一个用户的打分,再返回给用户。

在线层最大的问题就是对实时性要求特别高,一般来说是几十毫秒,这就限制了我们能做的工作,很多任务往往无法及时完成,需要近线层协助我们做。

1.2.2 算法架构

在实际的工业场景中,不管是用户维度、物品维度还是用户和物品的交互维度,数据都是极其丰富的,学术界对算法的使用方法不能照搬到工业界。当一个用户访问推荐模块时,系统不可能针对该用户对所有的物品进行排序(学术界做法)。所以一个通用的算法架构,设计思想就是对数据层层建模,层层筛选,帮助用户从海量数据中找出其真正感兴趣的部分,如下图所示:

召回

召回层的主要目标时从推荐池中选取几千上万的item,送给后续的排序模块。由于召回面对的候选集十分大,且一般需要在线输出,故召回模块必须轻量快速低延迟。由于后续还有排序模块作为保障,召回不需要十分准确,但不可遗漏。

目前基本上采用多路召回解决范式(一方面各路可以并行计算,另一方面取长补短),分为非个性化召回和个性化召回。个性化召回又有content-based、behavior-based、feature-based等多种方式。

召回主要考虑的内容有:

  • 考虑用户层面:用户兴趣的多元化,用户需求与场景的多元化:例如:新闻需求,重大要闻,相关内容沉浸阅读等等
  • 考虑系统层面:增强系统的鲁棒性;部分召回失效,其余召回队列兜底不会导致整个召回层失效;排序层失效,召回队列兜底不会导致整个推荐系统失效
  • 系统多样性内容分发:图文、视频、小视频;精准、试探、时效一定比例;召回目标的多元化,例如:相关性,沉浸时长,时效性,特色内容等等
  • 可解释性推荐一部分召回是有明确推荐理由的:很好的解决产品性数据的引入;

粗排

粗排的原因是有时候召回的结果还是太多,精排层速度还是跟不上,所以加入粗排。粗排可以理解为精排前的一轮过滤机制,减轻精排模块的压力。粗排介于召回和精排之间,要同时兼顾精准性和低延迟。

粗排阶段的架构设计主要是考虑三个方面,一个是根据精排模型中的重要特征,来做候选集的截断,另一部分是有一些召回设计,比如热度或者语义相关的这些结果,仅考虑了item侧的特征。最后是算法的选型要在在线服务的性能上有保证,因为这个阶段在pipeline中完成从召回到精排的截断工作,在延迟允许的范围内能处理更多的召回候选集理论上与精排效果正相关。

精排

精排需要在最大时延允许的情况下,保证打分的精准性,是整个系统中至关重要的一个模块,也是最复杂,研究最多的一个模块。

精排和粗排层的基本目标是一致的,但是和粗排不同的是,精排只需要对少量的商品进行排序即可。因此,精排中可以使用比粗排更多的特征,更复杂的模型和更精细的策略(用户的特征和行为在该层的大量使用和参与也是基于这个原因)。

目前精排层深度学习已经一统天下了,精排阶段采用的方案相对通用,首先我们要解决的是样本规模的问题,尽量多的喂给模型去记忆,另一个方面时效性上,用户的反馈产生的时候,怎么尽快的把新的反馈给到模型里去,学到最新的知识。

重排

重排序阶段对精排生成的Top-N个物品的序列进行重新排序,生成一个Top-K个物品的序列,作为排序系统最后的结果,直接展现给用户。重排序的原因是因为多个物品之间往往是相互影响的,而精排序是根据PointWise得分,容易造成推荐结果同质化严重,有很多冗余信息。一般在精排层我们使用AUC作为指标,但是在重排序更多关注NDCG等指标。

重排序在业务中,获取精排的排序结果,还会根据一些策略、运营规则参与排序,比如强制去重、间隔排序、流量扶持等、运营策略、多样性、context上下文等,重新进行一个微调。

由于精排模型一般比较复杂,基于系统时延考虑,一般采用point-wise方式,并行对每个item进行打分。这就使得打分时缺少了上下文感知能力。用户最终是否会点击购买一个商品,除了和它自身有关外,和它周围其他的item也息息相关。重排一般比较轻量,可以加入上下文感知能力,提升推荐整体算法效率。比如三八节对美妆类目商品提权,类目打散、同图打散、同卖家打散等保证用户体验措施。重排中规则比较多,但目前也有不少基于模型来提升重排效果的方案。

混排

多个业务线都想在Feeds流中获取曝光,则需要对它们的结果进行混排。比如推荐流中插入广告、视频流中插入图文和banner等。可以基于规则策略(如广告定坑)和强化学习来实现。

1.3 算法和核心技术(算法层)

本文将会分别从算法和工程两个角度出发,按照推荐系统的环节来介绍当前主流的一些推荐算法技术栈。

1.3.1 画像层

这个环节是推荐系统架构的基础设施,一般可能新用户/商品进来,或者每周定期会重新一次整个物料库,计算其中信息,为用户打上标签,计算统计信息,为商品做内容理解等内容。下面是微信看一看的内容画像框架,可以看到涉及了很多技术。

  • 文本理解:包括item的标题、正文、OCR、评论等数据。简单的算法是提取一些关键词来表征(TF-IDF、Bert、LSTM-CRF等),考虑到近义词等语义关系,目前流行用嵌入式表征(RNN、TextCNN、FastText、Bert等),也可训练模型做文本分类;
  • 多模态融合:在很多场景下,推荐的主题都是视频或者图片,远远多于仅推荐文本的情况,需要从视频和图片文本及其元数据来抽取更多的信息(TSN、RetinaFace、PSENet等)
  • 知识图谱:为知识承载系统,用于对接内外部关键词信息与词关系信息;内容画像会将原关系信息整合,并构建可业务应用的关系知识体系(KGAT、RippleNet等)

1.3.2 召回/粗排

召回一般都是多路召回,从模型角度分析有很多召回算法,这种一般是在召回层占大部分比例点召回。除此之外,还会有探索类召回、策略运营类召回、社交类召回等。接下来我们着重介绍模型类召回。

经典模型召回

通过某种算法,对 user 和 item 分别打上 Embedding,然后 user 与 item 在线进行 KNN 计算实时查询最近邻结果作为召回结果,快速找出匹配的物品。需要注意的是如果采用模型召回方法,优化目标最好和排序的优化目标一致,否则可能被过滤掉。

在这方面典型的算法有:FM、双塔DSSM、Multi-View DNN等。

序列模型召回

用户在使用 APP 或者网站的时候,一般会产生一些针对物品的行为,比如点击一些感兴趣的物品,收藏或者互动行为,或者是购买商品等。而一般用户之所以会对物品发生行为,往往意味着这些物品是符合用户兴趣的,而不同类型的行为,可能代表了不同程度的兴趣。比如购买就是比点击更能表征用户兴趣的行为。在召回阶段,如何根据用户行为序列打 embedding,可以采取有监督的模型,比如 Next Item Prediction 的预测方式即可;也可以采用无监督的方式,比如物品只要能打出 embedding,就能无监督集成用户行为序列内容,例如 Sum Pooling。

这方面典型的算法有:CBOW、Skip-Gram、GRU、Bert等。

用户序列拆分

用户往往是多兴趣的,比如可能同时对娱乐、体育、收藏感兴趣。这些不同的兴趣也能从用户行为序列的物品构成上看出来,比如行为序列中大部分是娱乐类,一部分体育类,少部分收藏类等。用户多兴趣拆分可以更细致刻画用户兴趣的方向。

本质上,把用户行为序列打到多个 embedding 上,实际它是个类似聚类的过程,就是把不同的 Item,聚类到不同的兴趣类别里去。目前常用的拆分用户兴趣 embedding 的方法,主要是胶囊网络和 Memory Network,但是理论上,很多类似聚类的方法应该都是有效的,所以完全可以在这块替换成你自己的能产生聚类效果的方法来做。

这方面典型的算法有:Multi-Interest Network with Dynamic Routing for Recommendation at Tmall等。

知识图谱

知识图谱有一个独有的优势和价值,那就是对于推荐结果的可解释性:比如推荐给用户某个物品,可以在知识图谱里通过物品的关键关联路径给出合理解释,这对于推荐结果的解释性来说是很好的。缺点在于,在排序角度来看,是效果最差的一类方法。所以往往可以利用知识图谱构建一条可解释性的召回通路。

这方面的算法有:KGAT、RippleNet等。

图模型

推荐系统中User和Item相关的行为、需求、属性和社交信息具有天然的图结构,可以使用一张复杂的异构图来表示整个推荐系统。图神经网络模型推荐就是基于这个想法,把异构网络中包含的结构和语义信息编码到结点Embedding表示中,并使用得到向量进行个性化推荐。

这方面典型的算法有:GraphSAGE、PinSage等。

1.3.3 精排

排序模型是推荐系统中涵盖的研究方向最多,有非常多的子领域值得研究探索,这也是推荐系统中技术含量最高的部分,毕竟它是直接面对用户,产生的结果对用户影响最大的一层。目前精排层深度学习已经一统天下了,下图是王喆老师《深度学习推荐算法》书中的精排层模型演化线路。

特征交叉模型

在深度学习推荐算法发展早期,很多论文聚焦于如何提升模型的特征组合和交叉的能力,这其中既包含隐式特征交叉Deep Crossing也有采用显式特征交叉的探究。本质上是希望模型能够摆脱人工先验的特征工程,实现端到端的一套模型。

这方面的经典研究工作有:DCN、DeepFM、xDeepFM等。

序列模型

在序列建模中,主要任务目标是得到用户此刻的兴趣向量(user interest vector)。如何刻画用户兴趣的广泛性,是推荐系统比较大的一个难点,用户历史行为序列建模的研究经历了从Pooling、RNN到attention、capsule再到transformer的顺序。在序列模型中,又有很多细分的方向,比如根据用户行为长度有研究用户终身行为序列的,也有聚焦当下兴趣的,还有研究如何抽取序列特征的抽取器,比如研究attention还是胶囊网络。

这方面典型的研究工作有:DIN、DSIN、DIEN、SIM等。

多模态信息融合

传统做法在多模态信息融合就是希望把不同模态信息利用起来,通过Embedding技术融合进模型。在推荐领域,主流的做法还是一套非端到端的体系,由其他模型抽取出多模态信息,推荐只需要融合入这些信息就好了。同时也有其他工作是利用注意力机制等方法来学习不同模态之间的关联,来增强多模态的表示。

比较典型的工作有:Image Matters: Visually modeling user behaviors using Advanced Model Server、UMPR等。

多任务学习

很多场景下我们模型优化的目标都是CTR,有一些场景只考虑CTR是不够的,点击率模型、时长模型和完播率模型是大部分信息流产品推荐算法团队都会尝试去做的模型。单独优化点击率模型容易推出来标题党,单独优化时长模型可能推出来的都是长视频或长文章,单独优化完播率模型可能短视频短图文就容易被推出来,所以多目标就应运而生。这些概率实际上就是模型要学习的目标,多种目标综合起来,包括阅读、点赞、收藏、分享等等一系列的行为,归纳到一个模型里面进行学习,这就是推荐系统的多目标学习。

这方面比较典型的算法有:ESSM、MMoE、DUPN等。

强化学习

先强化学习能够比较灵活地定义优化的业务目标,考虑推荐系统长短期的收益,比如用户留存。在深度模型下,我们很难设计这个指标的优化函数,而强化学习是可以对长期收益下来建模。第二是能够体现用户兴趣的动态变化,比如在新闻推荐下,用户兴趣变化很快,强化学习更容易通过用户行为动态产生推荐结果。最后是EE也就是利用探索机制,这种一种当前和长期收益的权衡,强化学习能够更好的调节这里的回报。

这方面比较典型的算法有:DQN、Reinforcement Learning for Slate-based Recommender Systems: A Tractable Decomposition and Practical Methodology。

跨域推荐

一般一家公司业务线都是非常多的,比如腾讯既有腾讯视频,也有微信看一看、视频号,还有腾讯音乐,如果能够结合这几个场景的数据,同时进行推荐,一方面对于冷启动是非常有利的,另一方面也能补充更多数据,更好的进行精确推荐。

跨域推荐系统相比一般的推荐系统要更加复杂。我们要关心在不同领域间要选择何种信息进行迁移,以及如何迁移这些信息。

这方面典型的模型有:DTCDR、MV-DNN、EMCDR等。

1.3.4 重排

重排序在业务中会根据一些策略、运营规则参与排序,比如强制去重、间隔排序、流量扶持等,但是总计趋势上看还是算法排序越来越占据主流趋势。重排序更多的是List Wise作为优化目标的,它关注的是列表中商品顺序的问题来优化模型,但是一般List Wise因为状态空间大,存在训练速度慢的问题。这方面典型的做法,基于RNN、Transformer、强化学习的都有。

这里的经典算法有:MRR、DPP、RNN等。

1.3.5 工程

推荐系统的实现需要依托工程,很多研究界Paper的idea满天飞,却忽视了工业界能否落地,进入工业界我们很难或者很少有组是做纯research的,所以我们同样有很多工程技术需要掌握。下面列举了在推荐中主要用到的工程技术:

  • 编程语言:Python、Java(scala)、C++、sql、shell;
  • 机器学习:Tensorflow/Pytorch、GraphLab/GraphCHI、LGB/Xgboost、SKLearn;
  • 数据分析:Pandas、Numpy、Seaborn、Spark;
  • 数据存储:mysql、redis、mangodb、hive、kafka、es、hbase;
  • 相似计算:annoy、faiss、kgraph
  • 流计算:Spark Streaming、Flink
  • 分布式:Hadoop、Spark

最重要的就是加粗的三部分,第一是语言:必须掌握的是Python,C++和JAVA中根据不同的组使用的是不同的语言,这个如果没有时间可以等进组后慢慢学习。然后是机器学习框架:Tensorflow和Pytorch至少要掌握一个,前期不用纠结学哪个,这个迁移成本很低,基本能够达到触类旁通,而且面试官不会为难你只会这个不会那个。最后是数据分析工具:Pandas是我们处理单机规模数据的利器,但是进入工业界,Hadoop和Saprk是需要会用的,不过不用学太深,会用即可。

2. Torch-RecHub

Torch-RecHub是一个轻量级的pytorch推荐模型框架,可以用来做工程实践和研究。本质上就一个基于PyTorch的Python库,可以下载调用。这个库目前还在由开发团队不断拓展中,感兴趣的也可以申请加入团队。

2.1 安装

pip install torch-rechub

在作者安装的时候,出现了一些错误,可能是因为 Windows 操作系统的原因。

BUG1:显示未发现 nose 库

解决方案:

pip install nose  # 版本需要 > 1.0.0
pip install --upgrade nose  # 若版本不够,可更新

BUG2: 显示安装 annoy 库失败,缺乏 C++ 环境

解决方案:安装 C++ 后再重新安装 torch-rechub,详情可参考 Microsoft Visual C++ 14.0 is required解决方法

PS:手动下载 annoy 离线安装可能会失败,显示缺乏依赖库的安装(手动安装,不会自动安装所依赖的库?)。推荐安装 C++,一劳永逸。

2.2 框架结构

2.3 简单案例(精排 CTR预测)

以Amazon-Electronics为例,原数据是json格式,需要预处理为一个仅包含user_id, item_id, cate_id, time四个特征列的CSV文件。可通过高速网盘链接下载:https://cowtransfer.com/s/e911569fbb1043

案例中所使用的模型是 Deep Interest Network(DIN). 整体的模型架构如下所示,可以初步了解一下输入和输出的类型,便于调用。

2.3.1 数据读取

file_path = '../examples/ranking/data/amazon-electronics/amazon_electronic_datasets.csv' # 需要将下载的文件当在 file_path 中
data = pd.read_csv(file_path)
data = data.iloc[:10000,]  # 原数据集太大,有169万行,本文只取前一万行
data

其中,cate_id 指 item 所属的 category 的唯一标识。time 的值越大,时间就越晚,可以转化为 Date 等数据对象。这些数据可以告诉我们,哪个 user 在什么时间点击了哪一个 item,这些数据可用来训练点击率预测模型。

2.3.2 特征工程

深度学习模型基本都需要有输入和输出来进行训练,DIN也不例外。因此,我们需要将原始数据转换为模型要求的输入类型(特征工程),才能进一步调用。

from torch_rechub.utils.data import create_seq_features
# 构建用户的历史行为序列特征,只需要指定数据,和需要生成序列的特征
# drop_short是选择舍弃行为序列较短的用户(默认为3)
# max_len是选择序列的最大长度(截断),若不足则补0,默认为50

train, val, test = create_seq_features(data, seq_feature_col=['item_id', 'cate_id'], drop_short=0)
# 查看当前构建的序列,在这个案例中我们创建了历史点击序列,和历史类别序列
# 该命令也生成了很多负样本,用于训练。要知道,原始数据中只包括正样本(未被点击的项目不被记录)

train # 与 val,test 一样都是 Dataframe 格式

不同类别的特征要使用不同的处理方式。在这个案例中,因为我们使用user_id,item_id和item_cate这三个类别特征,使用用户的item_id和cate的历史序列作为序列特征。在torch-rechub我们只需要调用DenseFeature, SparseFeature, SequenceFeature这三个类,就能让模型自动正确处理每一类特征。

from torch_rechub.basic.features import DenseFeature, SparseFeature, SequenceFeature

# 获得每个类别的数量,用来指定 vocab_size 的大小
n_users, n_items, n_cates = data["user_id"].max(), data["item_id"].max(), data["cate_id"].max()

# 对于sparsefeature,需要输入embedding层,所以需要指定特征空间大小和输出的维度
features = [SparseFeature("target_item", vocab_size=n_items + 2, embed_dim=8),
            SparseFeature("target_cate", vocab_size=n_cates + 2, embed_dim=8),
            SparseFeature("user_id", vocab_size=n_users + 2, embed_dim=8)]
target_features = features

# 对于序列特征,除了需要和类别特征一样处理意外,item序列和候选item应该属于同一个空间,我们希望模型共享它们的embedding,所以可以通过shared_with参数指定
history_features = [
    SequenceFeature("history_item", vocab_size=n_items + 2, embed_dim=8, pooling="concat", shared_with="target_item"),
    SequenceFeature("history_cate", vocab_size=n_cates + 2, embed_dim=8, pooling="concat", shared_with="target_cate")
]

2.3.3 实验数据载入

数据预处理后,获得特征后,就要按照深度学习的训练测试范式来构建批处理的 dataloader。

from torch_rechub.utils.data import df_to_dict, DataGenerator

# 指定label,生成模型的输入,这一步是转换为字典结构
train = df_to_dict(train)
val = df_to_dict(val)
test = df_to_dict(test)

train_y, val_y, test_y = train["label"], val["label"], test["label"]

# 删除 y
del train["label"]
del val["label"]
del test["label"]
train_x, val_x, test_x = train, val, test

# 最后查看一次输入模型的数据格式
print(train_x)

# 构建dataloader,指定模型读取数据的方式,和区分验证集测试集、指定batch大小
dg = DataGenerator(train_x, train_y)
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader(x_val=val_x, y_val=val_y, x_test=test_x, y_test=test_y, batch_size=16)

可以看到,train_x 是词典的数据类型。

2.3.4 训练新模型

from torch_rechub.models.ranking import DIN
from torch_rechub.trainers import CTRTrainer

# 定义模型,模型的参数需要我们之前的feature类,用于构建模型的输入层,mlp指定模型后续DNN的结构,attention_mlp指定attention层的结构
model = DIN(features=features, history_features=history_features, target_features=target_features, mlp_params={"dims": [256, 128]}, attention_mlp_params={"dims": [256, 128]})

# 模型训练,需要学习率、设备等一般的参数,此外我们还支持earlystoping策略,及时发现过拟合
ctr_trainer = CTRTrainer(model, optimizer_params={"lr": 1e-3, "weight_decay": 1e-3}, n_epoch=3, earlystop_patience=4, device='cpu', model_path='./')
ctr_trainer.fit(train_dataloader, val_dataloader)

# 查看在测试集上的性能
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
print(f'test auc: {auc}')

参考资料

posted @ 2022-06-13 15:49  Junwei_Kuang  阅读(1255)  评论(0编辑  收藏  举报