Deepctr框架代码阅读

DeepCtr是一个简易的CTR模型框架,集成了深度学习流行的所有模型,适合学推荐系统模型的人参考。我在参加比赛中用到了这个框架,但是效果一般,为了搞清楚原因从算法和框架两方面入手。在读代码的过程中遇到一些不理解的问题,所以记录在这里。

1. DeepFM模型的整体流程

preprocess_input_embedding:
	create_singlefeat_inputdict:
		搞成Feat是为了整体封装好,然后输入到Input的时候可以一一对应
		dense和spare直接放入keras的Input,格式是dict{key是feat名字,value是Input层结果}
	create_varlenfeat_inputdict
		序列是直接用max len放入keras的Input
	get_inputs_embedding
		create_embedding_dict(embedding层)
			稀疏特征:自动指定embeddingsize,6 * int(pow(feat.dimension, 0.25),否则按照指定的embeddingsize,使用L2正则化
			序列特征和稀疏特征的流程一样,封装Embedding多了mask_zero
			结果都是dict{key是feat名字,value是Embedding层结果},处理的是稀疏和序列特征
		get_embedding_vec_list(是embedding值)
			如果指定是hash,用自己写的Hash函数将特征的索引(是Input层的结果,也就是原始是数据输入)转换成hash函数,如果不是就直接用原始的特征索引。用特征索引和对应的特征在Embedding获取输出。
			这里只处理sparse
		merge_sequence_input
			这里处理序列
			get_varlen_embedding_vec_dict
				hash的时候全部填充0,和之前的区别是之前指定的才填充0
				这里和处理的sparse的方式一样,区别是用sequence_input_dict,但是这个和sparse都是用OrderedDict,区别是用最大长度,名字加seq
			get_pooling_vec_list
				如果没有最长长度或者长度序列为空,不填充。在SequencePoolingLayer对序列特征进行pool
			把结果加到之前sparse的结果上
			返回全部结果		
		merge_dense_input
			把原始的embedding拼接成列向量,加到之前的结果上
		如果有线性:
			稀疏向量的embedding长度是1,只做融合稀疏和序列,流程和之前一样
		inputs_list是稀疏、稠密、序列和序列长度
返回deep_emb_list, linear_emb_list, dense_input_dict, inputs_list
get_linear_logit
	如果有linear embedding,就是dense input经过全连接层之后加linear embedding,否则是dense input经过全连接层之后直接输出。默认创造linear embedding
embedding拉平放入FM和Deep中,然后是linear+FM+Deep,如果都有的话,有一个加一个
	FM:和的平方-平方的和
	DNN:默认两层128*128
最后所有的结果放入PredictionLayer,也就是连一个softmax或者sigmoid	

2. 框架优点

  • 整体结构清晰灵活,linear返回logit,FM层返回logit,deep包含中间层结果,在每一种模型中打包deep的最后一层,判断linear,fm和deep是否需要,最后接入全连接层。
  • 主要用到的模块和架构: keras的Concatenate(list转tensor),Dense(最后的全连接层和dense),Embedding(sparse,dense,sequence),Input(sparse,dense,sequce)还有常规操作:优化器,正则化项
  • 复用了重载了Layer层,重写了build,call,compute_output_shape,compute_mask,get_config

3. 框架缺点

  • 给定的参数都是论文提供的参数,实际使用存在问题,都需要自己修改!
  • 好多参数没有留接口,比如回归问题的loss 是mean_squared_error,只能通过硬写来修改参数
  • 如果想实现自己的模型,复用这个框架,需要了解keras,同时改很多接口,时间代价较大。

运行模型,每次结果不一样:
这个属于正常现象,尤其是数据不够充分的情况下,
主要原因是由于Tensorflow底层的多线程运行机制以及一些具有随机性的op和random seed导致的。
如果想让每次运行的结果尽量一致,可以考虑使用CPU运行程序,并且指定单线程运行,同时固定random seed,包括python自身的,Numpy的还有tensorflow的

4. 思考的问题

  1. 为什么获取Embedding的时候要hash?为什么是在这个地方hash?
    慢慢想!
  2. 为什么linear需要全连接+Embedding,为什么默认为Linear加入Embeding操作?
    慢慢想!
  3. FM模型的连续特征是怎么处理的?
    1. 离散化后在输入模型,事实上离散化后的模型更适合工业流水线环境。
    2. embedding一般是表示特征,embedding的话一般不能用来表示连续变量,拿年龄举例来说就是不能让10岁和40岁用同一个vector来表示。所以,还是要做one-hot处理的,也可以改动一下FM的结构,在一阶部分保留原始的连续特征。
    3. w&d是将连续值特征转换成累计分布形式,只针对离散特征去做fm和特征交叉;而它的连续值直接当作embedding向量和离散特征的embedding是拼接起来输入到神经网络里面去的。这里DeepFM模型也是这样做的。
  4. 关于embedding降维的思考:
    比如用户是50个binary特征,广告有100个binary特征,那预测用户是否会点击某条广告: 用fm把这些特征都抽象成10维的embedding,而且只做用户和广告的特征交叉,那把用户侧的embedding对应元素相加,这就压缩成10维了,广告侧也这么做,也变成10维了。这时候衡量用户和广告的相关性就直接拿这两个10的向量内积一下。这样实现降维的目的。
  5. 序列化embedding的方式是pooling的方式
    这是参考youtobe的做法
  6. FM的embedding size通常设置的不大比如4或者8这个量级,但是在深度学习中一般设置的会相对大,比如32、64,128。这是为什么?
    实验决定embeddingsize,有专家说过embeddingsize对最后的结果影响不大,所以只是一般使用128.
  7. 是否能实现有sparse的DeepFM
    为了稀疏输入,可以直接替换Tensor的类型
    ids = tf.SparseTensor(sparse_index, sparse_ids, sparse_shape) values = tf.SparseTensor(sparse_index, sparse_values, sparse_shape)
  8. 在这个帖子中提到embedding look up性能非常低,没有解决方案。
    https://zhuanlan.zhihu.com/p/39774203
  9. concat_fun 这里是concat什么?fm的输入为什么需要concat?
    原来是list,每一行是一个tensor,concat之后是tensor,每一行是tensor
  10. tf.keras.layers.Flatten()(fm_input)
    原先的embedding输入是[d,f,k],deep embedding是[d,f*k]
  11. 这里的实现和我的实现不一样:
    我的linear+interact+deep接入全连接层,将所有的特征接入全连接层, 但是根据根据论文和多家的博客来看,我之前理解的是错误的,正确的应该是 fm logit+deep logit,最后接全连接层。 同时AFM等多个模型都是这么处理的。

5. 学习到的python知识

  1. cls:表示类本身,和静态函数很像,但是比静态函数多一个功能,就是调用一般的函数。
  2. namedtuple:表示可访问属性的tuple
  3. isinstance:判断对象是否是一个已知类型
  4. 使用__import__ + list 导入需要的包
  5. new和init的区别:一个是创建,一个是初始化

这里写框架阅读写的有点乱,下次可以按照以下三个要点来记录:

  1. 输入数据的方式
  2. 对不同数据的处理
  3. 最后的结合
posted @ 2019-08-27 12:21  小小小的程序猿  阅读(3374)  评论(0编辑  收藏  举报
window.onload = function(){ $("#live2dcanvas").attr("style","position: fixed; opacity: 0.7; left: 70px; bottom: 0px; z-index: 1; pointer-events: none;") }