使用LFM(Latent factor model)隐语义模型进行Top-N推荐
最近在拜读项亮博士的《推荐系统实践》,系统的学习一下推荐系统的相关知识。今天学习了其中的隐语义模型在Top-N推荐中的应用,在此做一个总结。
隐语义模型LFM和LSI,LDA,Topic Model其实都属于隐含语义分析技术,是一类概念,他们在本质上是相通的,都是找出潜在的主题或分类。这些技术一开始都是在文本挖掘领域中提出来的,近些年它们也被不断应用到其他领域中,并得到了不错的应用效果。比如,在推荐系统中它能够基于用户的行为对item进行自动聚类,也就是把item划分到不同类别/主题,这些主题/类别可以理解为用户的兴趣。
对于一个用户来说,他们可能有不同的兴趣。就以作者举的豆瓣书单的例子来说,用户A会关注数学,历史,计算机方面的书,用户B喜欢机器学习,编程语言,离散数学方面的书, 用户C喜欢大师Knuth, Jiawei Han等人的著作。那我们在推荐的时候,肯定是向用户推荐他感兴趣的类别下的图书。那么前提是我们要对所有item(图书)进行分类。那如何分呢?大家注意到没有,分类标准这个东西是因人而异的,每个用户的想法都不一样。拿B用户来说,他喜欢的三个类别其实都可以算作是计算机方面的书籍,也就是说B的分类粒度要比A小;拿离散数学来讲,他既可以算作数学,也可当做计算机方面的类别,也就是说有些item不能简单的将其划归到确定的单一类别;拿C用户来说,他倾向的是书的作者,只看某几个特定作者的书,那么跟A,B相比它的分类角度就完全不同了。
显然我们不能靠由单个人(编辑)或team的主观想法建立起来的分类标准对整个平台用户喜好进行标准化。
此外我们还需要注意的两个问题:
- 我们在可见的用户书单中归结出3个类别,不等于该用户就只喜欢这3类,对其他类别的书就一点兴趣也没有。也就是说,我们需要了解用户对于所有类别的兴趣度。
- 对于一个给定的类来说,我们需要确定这个类中每本书属于该类别的权重。权重有助于我们确定该推荐哪些书给用户。
下面我们就来看看LFM是如何解决上面的问题的?对于一个给定的用户行为数据集(数据集包含的是所有的user, 所有的item,以及每个user有过行为的item列表),使用LFM对其建模后,我们可以得到如下图所示的模型:(假设数据集中有3个user, 4个item, LFM建模的分类数为4)
R矩阵是user-item矩阵,矩阵值Rij表示的是user i 对item j的兴趣度,这正是我们要求的值。对于一个user来说,当计算出他对所有item的兴趣度后,就可以进行排序并作出推荐。LFM算法从数据集中抽取出若干主题,作为user和item之间连接的桥梁,将R矩阵表示为P矩阵和Q矩阵相乘。其中P矩阵是user-class矩阵,矩阵值Pij表示的是user i对class j的兴趣度;Q矩阵式class-item矩阵,矩阵值Qij表示的是item j在class i中的权重,权重越高越能作为该类的代表。所以LFM根据如下公式来计算用户U对物品I的兴趣度
我们发现使用LFM后,
- 我们不需要关心分类的角度,结果都是基于用户行为统计自动聚类的,全凭数据自己说了算。
- 不需要关心分类粒度的问题,通过设置LFM的最终分类数就可控制粒度,分类数越大,粒度约细。
- 对于一个item,并不是明确的划分到某一类,而是计算其属于每一类的概率,是一种标准的软分类。
- 对于一个user,我们可以得到他对于每一类的兴趣度,而不是只关心可见列表中的那几个类。
- 对于每一个class,我们可以得到类中每个item的权重,越能代表这个类的item,权重越高。
那么,接下去的问题就是如何计算矩阵P和矩阵Q中参数值。一般做法就是最优化损失函数来求参数。在定义损失函数之前,我们需要准备一下数据集并对兴趣度的取值做一说明。
数据集应该包含所有的user和他们有过行为的(也就是喜欢)的item。所有的这些item构成了一个item全集。对于每个user来说,我们把他有过行为的item称为正样本,规定兴趣度RUI=1,此外我们还需要从item全集中随机抽样,选取与正样本数量相当的样本作为负样本,规定兴趣度为RUI=0。因此,兴趣的取值范围为[0,1]。
采样之后原有的数据集得到扩充,得到一个新的user-item集K={(U,I)},其中如果(U,I)是正样本,则RUI=1,否则RUI=0。损失函数如下所示:
上式中的是用来防止过拟合的正则化项,λ需要根据具体应用场景反复实验得到。损失函数的优化使用随机梯度下降算法:
- 通过求参数PUK和QKI的偏导确定最快的下降方向;
- 迭代计算不断优化参数(迭代次数事先人为设置),直到参数收敛。
其中,α是学习速率,α越大,迭代下降的越快。α和λ一样,也需要根据实际的应用场景反复实验得到。本书中,作者在MovieLens数据集上进行实验,他取分类数F=100,α=0.02,λ=0.01。
【注意】:书中在上面四个式子中都缺少了
综上所述,执行LFM需要:
- 根据数据集初始化P和Q矩阵(这是我暂时没有弄懂的地方,这个初始化过程到底是怎么样进行的,还恳请各位童鞋予以赐教。)
- 确定4个参数:分类数F,迭代次数N,学习速率α,正则化参数λ。
LFM的伪代码可以表示如下:
- def LFM(user_items, F, N, alpha, lambda):
- #初始化P,Q矩阵
- [P, Q] = InitModel(user_items, F)
- #开始迭代
- For step in range(0, N):
- #从数据集中依次取出user以及该user喜欢的iterms集
- for user, items in user_item.iterms():
- #随机抽样,为user抽取与items数量相当的负样本,并将正负样本合并,用于优化计算
- samples = RandSelectNegativeSamples(items)
- #依次获取item和user对该item的兴趣度
- for item, rui in samples.items():
- #根据当前参数计算误差 PS:转载的博客中rui写成了eui
- eui = rui - Predict(user, item)
- #优化参数
- for f in range(0, F):
- P[user][f] += alpha * (eui * Q[f][item] - lambda * P[user][f])
- Q[f][item] += alpha * (eui * P[user][f] - lambda * Q[f][item])
- #每次迭代完后,都要降低学习速率。一开始的时候由于离最优值相差甚远,因此快速下降;
- #当优化到一定程度后,就需要放慢学习速率,慢慢的接近最优值。
- alpha *= 0.9
本人对书中的伪代码追加了注释,有不对的地方还请指正。
当估算出P和Q矩阵后,我们就可以使用(*)式计算用户U对各个item的兴趣度值,并将兴趣度值最高的N个iterm(即TOP N)推荐给用户。
总结来说,LFM具有成熟的理论基础,它是一个纯种的学习算法,通过最优化理论来优化指定的参数,建立最优的模型.
========================我是分割线我自豪=============================
我不懂Python, 所以按照书里的步骤用java实现了, 期间走了好多弯路
在这里说一下, 初始值的选择, 在迭代的时候alpha的初始值切记不要选太大了, 我之前一直用0.1, 然后每次都没收敛, 晕死我了, 还一直改代码+调试
以为是其它原因, 浪费了大半天时间
最后不小心把alpha改为0.5后发现就正常了
改时间把代码扔上来
准备下班了........................
-------------------------------------------------我是不需要理由的分割线----------------------------------------------------
时隔好久,为了准备面试,重要把之前的代码复习了一遍,顺便整理好放到Github上
现在就扔到博客上来吧
LFM的核心代码模块
1 package org.juefan.alg; 2 3 import java.text.SimpleDateFormat; 4 import java.util.ArrayList; 5 import java.util.Collections; 6 import java.util.Comparator; 7 import java.util.Date; 8 import java.util.HashMap; 9 import java.util.HashSet; 10 import java.util.List; 11 import java.util.Map; 12 import java.util.Set; 13 14 public class LFM { 15 16 public static final int latent = 100; 17 public static double alpha = 0.03; 18 public static double lambda = 0.01; 19 public static final int iteration = 1; 20 public static final int resys = 10; 21 22 public static Map<Integer, List<Float>> UserMap = new HashMap<Integer, List<Float>>(); 23 public static Map<Integer, List<Float>> ItemMap = new HashMap<Integer, List<Float>>(); 24 25 public static compares compare = new compares(); 26 27 public class State{ 28 public int TemID; 29 public Set<Integer> set = new HashSet<Integer>(); 30 public float sim; 31 32 /**用户集排序*/ 33 public State(Set<Integer> s, float s2){ 34 set.addAll(s); 35 sim = s2; 36 } 37 38 /**Item排序*/ 39 public State(Integer i, float s){ 40 TemID = i; 41 sim = s; 42 } 43 } 44 45 public static class compares implements Comparator<Object>{ 46 @Override 47 public int compare(Object o1, Object o2) { 48 State s1 = (State)o1; 49 State s2 = (State)o2; 50 return s1.sim < s2.sim ? 1:-1; 51 } 52 } 53 54 public String toString(){ 55 return "LFM"; 56 } 57 /** 58 * 加载用户与项目的集合并初始化隐含矩阵 59 * 注意隐含层的数值不能太大,建议在0.05左右 60 * @param user 61 * @param item 62 */ 63 public LFM(Set<Integer> user, Set<Integer> item){ 64 for(Integer u:user){ 65 List<Float> tList = new ArrayList<Float>(); 66 for(int i = 0; i < latent; i++) 67 tList.add((float) ((float) Math.random() * 0.1)); 68 UserMap.put(u, tList); 69 } 70 for(Integer u:item){ 71 List<Float> tList = new ArrayList<Float>(); 72 for(int i = 0; i < latent; i++) 73 tList.add((float) ((float) Math.random() *0.1)); 74 ItemMap.put(u, tList); 75 } 76 } 77 public LFM() { 78 // TODO Auto-generated constructor stub 79 } 80 /** 81 * 计算用户对某个物品的兴趣 82 * @param uLV 用户与隐含类的关系 83 * @param iLV 隐含类与物品的关系 84 * @return 返回用户对某个物品的兴趣 85 */ 86 public static float getPreference(List<Float> uLV, List<Float> iLV){ 87 float p = 0; 88 for(int i = 0; i < latent; i++){ 89 p = p + uLV.get(i) * iLV.get(i); 90 } 91 return p; 92 } 93 94 /**预测评分差*/ 95 public static float Predict(float i1, float i2){ 96 return i1 - i2; 97 } 98 99 /** 100 * 迭代求解隐含层 101 * @param UserItem 102 */ 103 public static void LatentFactorModel(Map<Integer, Map<Integer, Float>> UserItem){ 104 SimpleDateFormat df = new SimpleDateFormat("yyyy-MM-dd_HH-mm");//设置日期格式 105 for(int i = 0; i < iteration; i++){ 106 System.out.println( df.format(new Date()) + "\t第 " + (i + 1) + " 次迭代"); 107 for(int user: UserItem.keySet()){ 108 for(int item: UserItem.get(user).keySet()){ 109 float error = Predict(UserItem.get(user).get(item), 110 getPreference(UserMap.get(user), ItemMap.get(item))); 111 for(int i1 = 0; i1 < latent; i1++){ 112 UserMap.get(user).set(i1, (float) (UserMap.get(user).get(i1) + alpha * 113 (ItemMap.get(item).get(i1) * error - lambda * UserMap.get(user).get(i1)))); 114 if (Float.isNaN(UserMap.get(user).get(i1) )) { 115 System.err.println("矩阵初始化或者参数有问题导致矩阵出现数值溢出"); 116 } 117 ItemMap.get(item).set(i1, (float) (ItemMap.get(item).get(i1) + alpha * 118 (UserMap.get(user).get(i1) * error - lambda * ItemMap.get(item).get(i1)))); 119 } 120 } 121 } 122 alpha = (float) (alpha * 0.9); 123 } 124 } 125 126 /** 127 * 获取用户的最终推荐列表 128 * @param map 项目的得分值表 129 * @return 130 */ 131 public Set<Integer> getResysK(Map<Integer, Float> map){ 132 List<State> tList = new ArrayList<State>(); 133 Set<Integer> set = new HashSet<Integer>(); 134 for(Integer key: map.keySet()){ 135 tList.add(new State(key, map.get(key))); 136 } 137 Collections.sort(tList, compare); 138 for(int i = 0; i < tList.size() && i < resys; i++){ 139 set.add(tList.get(i).TemID); 140 } 141 return set; 142 } 143 144 /** 145 * 计算用户的推荐列表 146 * @param user 用户的ID 147 * @param item 用户的训练集 148 * @return 用户的推荐列表 149 */ 150 public Set<Integer> getResysList(int user, Map<Integer, Float> item){ 151 Map<Integer, Float> map = new HashMap<Integer, Float>(); 152 for(int i: ItemMap.keySet()){ 153 if(!item.containsKey(i)) 154 map.put(i, getPreference(UserMap.get(user), ItemMap.get(i))); 155 } 156 return getResysK(map); 157 } 158 }
接下来是具体的训练操作,比较重点的是训练数据和测试数据的选择,还有负例的生成
1 package org.juefan.alg.test; 2 3 import java.text.SimpleDateFormat; 4 import java.util.ArrayList; 5 import java.util.Date; 6 import java.util.HashMap; 7 import java.util.HashSet; 8 import java.util.List; 9 import java.util.Map; 10 import java.util.Set; 11 12 import org.juefan.IO.FileIO; 13 import org.juefan.alg.LFM; 14 import org.juefan.data.RatingData; 15 import org.juefan.eva.Evaluation; 16 17 public class TestLFM { 18 19 public static Set<Integer> user = new HashSet<Integer>(); 20 public static Set<Integer> item = new HashSet<Integer>(); 21 public static List<Integer> itemList = new ArrayList<Integer>(); 22 public static Map<Integer, Integer> map = new HashMap<Integer, Integer>(); 23 public static Map<Integer, Integer> randMap = new HashMap<Integer, Integer>(); //倾向选择热门且用户未评价的为负例 24 25 /**用户项目训练数据*/ 26 public static Map<Integer, Map<Integer, Float>> UserItemTrain = new HashMap<Integer, Map<Integer, Float>> (); 27 /**用户项目测试数据*/ 28 public static Map<Integer, Map<Integer, Float>> UserItemTest = new HashMap<Integer, Map<Integer, Float>> (); 29 30 public static LFM lfm = new LFM(); 31 32 public static Map<Integer, Float> getFu(Map<Integer, Float> item){ 33 Map<Integer, Float> map = new HashMap<Integer, Float>(); 34 while(map.size() < item.size()*4 && item.size() + map.size() < TestLFM.item.size() * 0.8){ 35 /**抑制热门方式*/ 36 /*int rand = (int) (Math.random() * randMap.size()); 37 if(!item.containsKey(randMap.get(rand))){ 38 map.put(randMap.get(rand), (float) 0); 39 }*/ 40 /**同等对待方式*/ 41 int rand = (int) (Math.random() * TestLFM.itemList.size()); 42 if(!item.containsKey( TestLFM.itemList.get(rand))){ 43 map.put( TestLFM.itemList.get(rand), (float) 0); 44 } 45 } 46 return map; 47 } 48 49 /**将Map的key加载进Set类*/ 50 public static Set<Integer> MapToSet(Map<Integer, Float> item){ 51 Set<Integer> tSet = new HashSet<Integer>(); 52 for(int k: item.keySet()) 53 tSet.add(k); 54 return tSet; 55 } 56 57 /** 58 * 测试入口 59 */ 60 public static void main(String[] args) { 61 System.setProperty("java.util.Arrays.useLegacyMergeSort", "true"); 62 FileIO fileIO = new FileIO(); 63 fileIO.SetfileName(System.getProperty("user.dir") + "\\data\\input\\ml-1m\\ratings.dat"); 64 fileIO.FileRead(); 65 List<String> list = fileIO.cloneList(); 66 int num = 0; 67 for(String s:list){ 68 RatingData data = new RatingData(s); 69 float rand = (float) Math.random(); 70 if(rand >= (float)1/8){ //将数据随机分成训练数据和测试数据 71 if(UserItemTrain.containsKey(data.userID)){ 72 UserItemTrain.get(data.userID).put(data.movieID, (float) 1); 73 }else { 74 Map<Integer, Float> tMap = new HashMap<Integer, Float>(); 75 tMap.put(data.movieID, (float) 1); 76 UserItemTrain.put(data.userID, tMap); 77 } 78 //计算每个项目的热度 79 if(map.containsKey(data.movieID)){ 80 map.put(data.movieID, map.get(data.movieID) + 1); 81 }else { 82 map.put(data.movieID, 1); 83 } 84 //构造项目分布映射 85 randMap.put(num++, data.movieID); 86 //收集用户列表和项目列表 87 user.add(data.userID); 88 item.add(data.movieID); 89 }else { 90 if(UserItemTest.containsKey(data.userID)){ 91 UserItemTest.get(data.userID).put(data.movieID, (float) 1); 92 }else { 93 Map<Integer, Float> tMap = new HashMap<Integer, Float>(); 94 tMap.put(data.movieID, (float) 1); 95 UserItemTest.put(data.userID, tMap); 96 } 97 } 98 } 99 100 101 System.out.println("正在构造罗盘赌"); 102 for(Integer item: TestLFM.item){ 103 itemList.add(item); 104 } 105 int Fu = 0; 106 for(int user: UserItemTrain.keySet()){ 107 UserItemTrain.get(user).putAll(getFu(UserItemTrain.get(user))); 108 if(++Fu % 1000 == 0) 109 System.out.println("已构造 " + Fu +" 个负样本用户数据"); 110 } 111 System.out.println("负样本生成完毕"); 112 113 SimpleDateFormat df = new SimpleDateFormat("yyyy-MM-dd_HH-mm");//设置日期格式 114 String dataString = "\\data\\output\\Result\\" + df.format(new Date()) + "_result.txt"; 115 LFM lfm = new LFM(user, item); 116 for(int trac = 0; trac <= 20; trac++){ 117 LFM.LatentFactorModel(UserItemTrain); 118 for(int user:UserItemTrain.keySet()){ 119 if(UserItemTest.containsKey(user)){ 120 Evaluation.setEvaluation(MapToSet(UserItemTest.get(user)), lfm.getResysList(user, UserItemTrain.get(user))); 121 } 122 } 123 System.out.println("准确率 = " + Evaluation.getPrecision() * 100 + "%\t\t召回率 = " + Evaluation.getRecall() * 100 + "%\t\t覆盖率 = " + Evaluation.getCoverage()/item.size() * 100 + "%"); 124 FileIO.FileWrite(System.getProperty("user.dir") + dataString, "===================使用算法 : " + lfm.toString() 125 + "=====================\n具体参数: " 126 + "\nlatent = " + LFM.latent 127 +"\nalpha = " + LFM.alpha 128 +"\nlambda = " + LFM.lambda 129 + "\n准确率 = " + Evaluation.getPrecision() * 100 + "%\t\t召回率 = " + Evaluation.getRecall() * 100 + "%\t\t覆盖率 = " + Evaluation.getCoverage()/item.size() * 100 + "%\n", true); 130 } 131 } 132 133 }
好了,基本上就这样了
如果要看完整的代码欢迎到本人的Github上查看,里面还有相应的数据,还有一个UserCF代码模块
Github地址:https://github.com/JueFan/RecommendSystem
参考文章: http://blog.csdn.net/harryhuang1990/article/details/9924377#reply