使用LFM(Latent factor model)隐语义模型进行Top-N推荐

最近在拜读项亮博士的《推荐系统实践》,系统的学习一下推荐系统的相关知识。今天学习了其中的隐语义模型在Top-N推荐中的应用,在此做一个总结。

隐语义模型LFM和LSI,LDA,Topic Model其实都属于隐含语义分析技术,是一类概念,他们在本质上是相通的,都是找出潜在的主题或分类。这些技术一开始都是在文本挖掘领域中提出来的,近些年它们也被不断应用到其他领域中,并得到了不错的应用效果。比如,在推荐系统中它能够基于用户的行为对item进行自动聚类,也就是把item划分到不同类别/主题,这些主题/类别可以理解为用户的兴趣。

对于一个用户来说,他们可能有不同的兴趣。就以作者举的豆瓣书单的例子来说,用户A会关注数学,历史,计算机方面的书,用户B喜欢机器学习,编程语言,离散数学方面的书, 用户C喜欢大师Knuth, Jiawei Han等人的著作。那我们在推荐的时候,肯定是向用户推荐他感兴趣的类别下的图书。那么前提是我们要对所有item(图书)进行分类。那如何分呢?大家注意到没有,分类标准这个东西是因人而异的,每个用户的想法都不一样。拿B用户来说,他喜欢的三个类别其实都可以算作是计算机方面的书籍,也就是说B的分类粒度要比A小;拿离散数学来讲,他既可以算作数学,也可当做计算机方面的类别,也就是说有些item不能简单的将其划归到确定的单一类别;拿C用户来说,他倾向的是书的作者,只看某几个特定作者的书,那么跟A,B相比它的分类角度就完全不同了。

显然我们不能靠由单个人(编辑)或team的主观想法建立起来的分类标准对整个平台用户喜好进行标准化。

此外我们还需要注意的两个问题:

  1. 我们在可见的用户书单中归结出3个类别,不等于该用户就只喜欢这3类,对其他类别的书就一点兴趣也没有。也就是说,我们需要了解用户对于所有类别的兴趣度。
  2. 对于一个给定的类来说,我们需要确定这个类中每本书属于该类别的权重。权重有助于我们确定该推荐哪些书给用户。

下面我们就来看看LFM是如何解决上面的问题的?对于一个给定的用户行为数据集(数据集包含的是所有的user, 所有的item,以及每个user有过行为的item列表),使用LFM对其建模后,我们可以得到如下图所示的模型:(假设数据集中有3个user, 4个item, LFM建模的分类数为4)

 

R矩阵是user-item矩阵,矩阵值Rij表示的是user i 对item j的兴趣度,这正是我们要求的值。对于一个user来说,当计算出他对所有item的兴趣度后,就可以进行排序并作出推荐。LFM算法从数据集中抽取出若干主题,作为user和item之间连接的桥梁,将R矩阵表示为P矩阵和Q矩阵相乘。其中P矩阵是user-class矩阵,矩阵值Pij表示的是user i对class j的兴趣度;Q矩阵式class-item矩阵,矩阵值Qij表示的是item j在class i中的权重,权重越高越能作为该类的代表。所以LFM根据如下公式来计算用户U对物品I的兴趣度

我们发现使用LFM后, 

  1. 我们不需要关心分类的角度,结果都是基于用户行为统计自动聚类的,全凭数据自己说了算。
  2. 不需要关心分类粒度的问题,通过设置LFM的最终分类数就可控制粒度,分类数越大,粒度约细。
  3. 对于一个item,并不是明确的划分到某一类,而是计算其属于每一类的概率,是一种标准的软分类。
  4. 对于一个user,我们可以得到他对于每一类的兴趣度,而不是只关心可见列表中的那几个类。
  5. 对于每一个class,我们可以得到类中每个item的权重,越能代表这个类的item,权重越高。

那么,接下去的问题就是如何计算矩阵P和矩阵Q中参数值。一般做法就是最优化损失函数来求参数。在定义损失函数之前,我们需要准备一下数据集并对兴趣度的取值做一说明。


数据集应该包含所有的user和他们有过行为的(也就是喜欢)的item。所有的这些item构成了一个item全集。对于每个user来说,我们把他有过行为的item称为正样本,规定兴趣度RUI=1,此外我们还需要从item全集中随机抽样,选取与正样本数量相当的样本作为负样本,规定兴趣度为RUI=0。因此,兴趣的取值范围为[0,1]。


采样之后原有的数据集得到扩充,得到一个新的user-item集K={(U,I)},其中如果(U,I)是正样本,则RUI=1,否则RUI=0。损失函数如下所示:

上式中的是用来防止过拟合的正则化项,λ需要根据具体应用场景反复实验得到。损失函数的优化使用随机梯度下降算法:

  1. 通过求参数PUK和QKI的偏导确定最快的下降方向;

  1. 迭代计算不断优化参数(迭代次数事先人为设置),直到参数收敛。



其中,α是学习速率,α越大,迭代下降的越快。α和λ一样,也需要根据实际的应用场景反复实验得到。本书中,作者在MovieLens数据集上进行实验,他取分类数F=100,α=0.02,λ=0.01。
               【注意】:书中在上面四个式子中都缺少了


综上所述,执行LFM需要:

  1. 根据数据集初始化P和Q矩阵(这是我暂时没有弄懂的地方,这个初始化过程到底是怎么样进行的,还恳请各位童鞋予以赐教。)
  2. 确定4个参数:分类数F,迭代次数N,学习速率α,正则化参数λ。

LFM的伪代码可以表示如下:

 

[python] view plaincopy
 
  1. def LFM(user_items, F, N, alpha, lambda):  
  2.     #初始化P,Q矩阵  
  3.     [P, Q] = InitModel(user_items, F)  
  4.     #开始迭代  
  5.     For step in range(0, N):  
  6.         #从数据集中依次取出user以及该user喜欢的iterms集  
  7.         for user, items in user_item.iterms():  
  8.             #随机抽样,为user抽取与items数量相当的负样本,并将正负样本合并,用于优化计算  
  9.             samples = RandSelectNegativeSamples(items)  
  10.             #依次获取item和user对该item的兴趣度  
  11.             for item, rui in samples.items():  
  12.                 #根据当前参数计算误差  PS:转载的博客中rui写成了eui
  13.                 eui = rui - Predict(user, item)  
  14.                 #优化参数  
  15.                 for f in range(0, F):  
  16.                     P[user][f] += alpha * (eui * Q[f][item] - lambda * P[user][f])  
  17.                     Q[f][item] += alpha * (eui * P[user][f] - lambda * Q[f][item])  
  18.         #每次迭代完后,都要降低学习速率。一开始的时候由于离最优值相差甚远,因此快速下降;  
  19.         #当优化到一定程度后,就需要放慢学习速率,慢慢的接近最优值。  
  20.         alpha *= 0.9  


本人对书中的伪代码追加了注释,有不对的地方还请指正。

 


当估算出P和Q矩阵后,我们就可以使用(*)式计算用户U对各个item的兴趣度值,并将兴趣度值最高的N个iterm(即TOP N)推荐给用户。

总结来说,LFM具有成熟的理论基础,它是一个纯种的学习算法,通过最优化理论来优化指定的参数,建立最优的模型.

 

 

========================我是分割线我自豪=============================

我不懂Python, 所以按照书里的步骤用java实现了, 期间走了好多弯路

在这里说一下, 初始值的选择, 在迭代的时候alpha的初始值切记不要选太大了, 我之前一直用0.1, 然后每次都没收敛, 晕死我了, 还一直改代码+调试

以为是其它原因, 浪费了大半天时间

最后不小心把alpha改为0.5后发现就正常了

改时间把代码扔上来

 

准备下班了........................

 

-------------------------------------------------我是不需要理由的分割线----------------------------------------------------

时隔好久,为了准备面试,重要把之前的代码复习了一遍,顺便整理好放到Github上

现在就扔到博客上来吧

LFM的核心代码模块

  1 package org.juefan.alg;
  2 
  3 import java.text.SimpleDateFormat;
  4 import java.util.ArrayList;
  5 import java.util.Collections;
  6 import java.util.Comparator;
  7 import java.util.Date;
  8 import java.util.HashMap;
  9 import java.util.HashSet;
 10 import java.util.List;
 11 import java.util.Map;
 12 import java.util.Set;
 13 
 14 public class LFM {
 15     
 16     public static final int latent = 100;
 17     public static double alpha = 0.03;
 18     public static double lambda = 0.01;
 19     public static final int iteration = 1;
 20     public static final   int resys = 10;
 21     
 22     public static Map<Integer, List<Float>> UserMap = new HashMap<Integer, List<Float>>();
 23     public static Map<Integer, List<Float>> ItemMap = new HashMap<Integer, List<Float>>();
 24         
 25     public static compares compare = new compares();
 26     
 27     public  class State{
 28         public int TemID;
 29         public Set<Integer> set = new HashSet<Integer>();
 30         public float sim;
 31 
 32         /**用户集排序*/
 33         public State(Set<Integer> s, float s2){
 34             set.addAll(s);
 35             sim = s2;
 36         }
 37 
 38         /**Item排序*/
 39         public State(Integer i, float s){
 40             TemID = i;
 41             sim = s;
 42         }
 43     }
 44 
 45     public static class compares implements Comparator<Object>{
 46         @Override
 47         public int compare(Object o1, Object o2) {            
 48             State s1 = (State)o1;
 49             State s2 = (State)o2;
 50             return s1.sim < s2.sim ? 1:-1;
 51         }        
 52     }
 53     
 54     public  String toString(){
 55         return "LFM";
 56     }
 57     /**
 58      * 加载用户与项目的集合并初始化隐含矩阵
 59      * 注意隐含层的数值不能太大,建议在0.05左右
 60      * @param user
 61      * @param item
 62      */
 63     public LFM(Set<Integer> user, Set<Integer> item){
 64         for(Integer u:user){
 65             List<Float> tList = new ArrayList<Float>();
 66             for(int i = 0; i < latent; i++)
 67                 tList.add((float) ((float) Math.random() * 0.1));
 68             UserMap.put(u, tList);
 69         }
 70         for(Integer u:item){
 71             List<Float> tList = new ArrayList<Float>();
 72             for(int i = 0; i < latent; i++)
 73                 tList.add((float) ((float) Math.random() *0.1));
 74             ItemMap.put(u, tList);
 75         }
 76     }
 77     public LFM() {
 78         // TODO Auto-generated constructor stub
 79     }
 80     /**
 81      *  计算用户对某个物品的兴趣
 82      * @param uLV    用户与隐含类的关系
 83      * @param iLV    隐含类与物品的关系
 84      * @return 返回用户对某个物品的兴趣
 85      */
 86     public static float getPreference(List<Float> uLV, List<Float> iLV){
 87         float p = 0;
 88         for(int i = 0; i < latent; i++){
 89             p = p + uLV.get(i) * iLV.get(i);
 90         }
 91         return p;
 92     }
 93     
 94     /**预测评分差*/
 95     public static float Predict(float i1, float i2){
 96         return i1 - i2;
 97     }
 98     
 99     /**
100      * 迭代求解隐含层
101      * @param UserItem    
102      */
103     public static void LatentFactorModel(Map<Integer, Map<Integer, Float>> UserItem){
104         SimpleDateFormat df = new SimpleDateFormat("yyyy-MM-dd_HH-mm");//设置日期格式
105         for(int i = 0; i < iteration; i++){    
106             System.out.println( df.format(new Date()) + "\t第 " + (i + 1) + " 次迭代");
107             for(int user: UserItem.keySet()){
108                 for(int item: UserItem.get(user).keySet()){
109                     float error = Predict(UserItem.get(user).get(item), 
110                             getPreference(UserMap.get(user), ItemMap.get(item)));
111                     for(int i1 = 0; i1 < latent; i1++){
112                         UserMap.get(user).set(i1, (float) (UserMap.get(user).get(i1) + alpha * 
113                                 (ItemMap.get(item).get(i1) * error - lambda * UserMap.get(user).get(i1))));
114                         if (Float.isNaN(UserMap.get(user).get(i1) )) {
115                             System.err.println("矩阵初始化或者参数有问题导致矩阵出现数值溢出");
116                         }
117                         ItemMap.get(item).set(i1, (float) (ItemMap.get(item).get(i1) + alpha * 
118                                 (UserMap.get(user).get(i1) * error - lambda * ItemMap.get(item).get(i1))));
119                     }
120                 }
121             }
122             alpha = (float) (alpha * 0.9);
123         }    
124     }
125     
126     /**
127      * 获取用户的最终推荐列表
128      * @param map 项目的得分值表
129      * @return
130      */
131     public Set<Integer> getResysK(Map<Integer, Float> map){
132         List<State> tList = new ArrayList<State>();
133         Set<Integer> set = new HashSet<Integer>();
134         for(Integer key: map.keySet()){
135             tList.add(new State(key,  map.get(key)));        
136         }
137         Collections.sort(tList, compare);
138         for(int i = 0; i < tList.size() && i < resys; i++){
139             set.add(tList.get(i).TemID);    
140         }
141         return set;
142     }
143     
144     /**
145      *  计算用户的推荐列表
146      * @param user    用户的ID
147      * @param item    用户的训练集
148      * @return    用户的推荐列表
149      */
150     public Set<Integer> getResysList(int user, Map<Integer, Float> item){
151         Map<Integer, Float> map = new HashMap<Integer, Float>();
152         for(int i: ItemMap.keySet()){
153             if(!item.containsKey(i))
154                 map.put(i, getPreference(UserMap.get(user), ItemMap.get(i)));
155         }
156         return getResysK(map);
157     }
158 }

接下来是具体的训练操作,比较重点的是训练数据和测试数据的选择,还有负例的生成

  1 package org.juefan.alg.test;
  2 
  3 import java.text.SimpleDateFormat;
  4 import java.util.ArrayList;
  5 import java.util.Date;
  6 import java.util.HashMap;
  7 import java.util.HashSet;
  8 import java.util.List;
  9 import java.util.Map;
 10 import java.util.Set;
 11 
 12 import org.juefan.IO.FileIO;
 13 import org.juefan.alg.LFM;
 14 import org.juefan.data.RatingData;
 15 import org.juefan.eva.Evaluation;
 16 
 17 public class TestLFM {
 18 
 19     public static Set<Integer> user = new HashSet<Integer>();
 20     public static Set<Integer> item = new HashSet<Integer>();
 21     public static List<Integer> itemList = new ArrayList<Integer>();
 22     public static Map<Integer, Integer> map = new HashMap<Integer, Integer>();
 23     public static Map<Integer, Integer> randMap = new HashMap<Integer, Integer>();    //倾向选择热门且用户未评价的为负例
 24 
 25     /**用户项目训练数据*/
 26     public static Map<Integer, Map<Integer, Float>>  UserItemTrain = new HashMap<Integer, Map<Integer, Float>> ();
 27     /**用户项目测试数据*/
 28     public static Map<Integer, Map<Integer, Float>>  UserItemTest = new HashMap<Integer, Map<Integer, Float>> ();
 29     
 30     public static LFM lfm = new LFM();
 31         
 32     public static Map<Integer, Float> getFu(Map<Integer, Float> item){
 33         Map<Integer, Float> map = new HashMap<Integer, Float>();    
 34         while(map.size() < item.size()*4 && item.size() + map.size() < TestLFM.item.size() * 0.8){
 35             /**抑制热门方式*/
 36             /*int rand = (int) (Math.random() * randMap.size());
 37             if(!item.containsKey(randMap.get(rand))){
 38                 map.put(randMap.get(rand), (float) 0);
 39             }*/
 40             /**同等对待方式*/
 41             int rand = (int) (Math.random() *  TestLFM.itemList.size());
 42             if(!item.containsKey( TestLFM.itemList.get(rand))){
 43                 map.put( TestLFM.itemList.get(rand), (float) 0);
 44             }
 45         }
 46         return map;
 47     }
 48 
 49     /**将Map的key加载进Set类*/
 50     public static Set<Integer> MapToSet(Map<Integer, Float> item){
 51         Set<Integer> tSet = new HashSet<Integer>();
 52         for(int k: item.keySet())
 53             tSet.add(k);
 54         return tSet;
 55     }
 56 
 57     /**
 58      * 测试入口
 59      */
 60     public static void main(String[] args) {
 61         System.setProperty("java.util.Arrays.useLegacyMergeSort", "true");  
 62         FileIO fileIO = new FileIO();
 63         fileIO.SetfileName(System.getProperty("user.dir") + "\\data\\input\\ml-1m\\ratings.dat");
 64         fileIO.FileRead();
 65         List<String> list = fileIO.cloneList();
 66         int num = 0;
 67         for(String s:list){
 68             RatingData data = new RatingData(s);
 69             float rand = (float) Math.random();
 70             if(rand >= (float)1/8){    //将数据随机分成训练数据和测试数据
 71                 if(UserItemTrain.containsKey(data.userID)){
 72                     UserItemTrain.get(data.userID).put(data.movieID, (float) 1);
 73                 }else {
 74                     Map<Integer, Float> tMap = new HashMap<Integer, Float>();
 75                     tMap.put(data.movieID, (float) 1);
 76                     UserItemTrain.put(data.userID, tMap);
 77                 }
 78                 //计算每个项目的热度
 79                 if(map.containsKey(data.movieID)){
 80                     map.put(data.movieID, map.get(data.movieID) + 1);
 81                 }else {
 82                     map.put(data.movieID, 1);
 83                 }
 84                 //构造项目分布映射
 85                 randMap.put(num++, data.movieID);
 86                 //收集用户列表和项目列表
 87                 user.add(data.userID);
 88                 item.add(data.movieID);
 89             }else {
 90                 if(UserItemTest.containsKey(data.userID)){
 91                     UserItemTest.get(data.userID).put(data.movieID, (float) 1);
 92                 }else {
 93                     Map<Integer, Float> tMap = new HashMap<Integer, Float>();
 94                     tMap.put(data.movieID, (float) 1);
 95                     UserItemTest.put(data.userID, tMap);
 96                 }
 97             }        
 98         }
 99         
100 
101         System.out.println("正在构造罗盘赌");
102         for(Integer item: TestLFM.item){
103             itemList.add(item);
104         }
105         int Fu = 0;
106         for(int user: UserItemTrain.keySet()){
107             UserItemTrain.get(user).putAll(getFu(UserItemTrain.get(user)));
108             if(++Fu % 1000 == 0)
109                 System.out.println("已构造 " + Fu +" 个负样本用户数据");
110         }
111         System.out.println("负样本生成完毕");
112 
113         SimpleDateFormat df = new SimpleDateFormat("yyyy-MM-dd_HH-mm");//设置日期格式
114         String dataString = "\\data\\output\\Result\\" + df.format(new Date()) + "_result.txt";
115         LFM lfm = new LFM(user, item);
116         for(int trac = 0; trac <= 20; trac++){
117             LFM.LatentFactorModel(UserItemTrain);
118             for(int user:UserItemTrain.keySet()){
119                 if(UserItemTest.containsKey(user)){
120                     Evaluation.setEvaluation(MapToSet(UserItemTest.get(user)), lfm.getResysList(user, UserItemTrain.get(user)));
121                 }
122             }
123             System.out.println("准确率 = " + Evaluation.getPrecision() * 100 + "%\t\t召回率 = " + Evaluation.getRecall() * 100 + "%\t\t覆盖率 = " + Evaluation.getCoverage()/item.size() * 100 + "%");
124             FileIO.FileWrite(System.getProperty("user.dir") + dataString, "===================使用算法 : " + lfm.toString()
125                     + "=====================\n具体参数: "            
126                     + "\nlatent = " + LFM.latent
127                     +"\nalpha = " + LFM.alpha
128                     +"\nlambda = " + LFM.lambda
129                     + "\n准确率 = " + Evaluation.getPrecision() * 100 + "%\t\t召回率 = " + Evaluation.getRecall() * 100 + "%\t\t覆盖率 = " + Evaluation.getCoverage()/item.size() * 100 + "%\n", true);
130         }
131     }
132 
133 }

好了,基本上就这样了

如果要看完整的代码欢迎到本人的Github上查看,里面还有相应的数据,还有一个UserCF代码模块

Github地址:https://github.com/JueFan/RecommendSystem

参考文章: http://blog.csdn.net/harryhuang1990/article/details/9924377#reply

posted on 2013-12-05 16:21  JueFan_C  阅读(1642)  评论(1编辑  收藏  举报

导航