推荐系统偏好SVD - C++

最近想整整推荐系统,比较经典的算法就是SVD了。具体理论不多讲了。直接上代码。

先贴张效果图吧。userNum 6040 itemNum 3900

本文链接:http://www.cnblogs.com/wn19910213/p/3617781.html

上代码咯:

SVD.h

 1 #ifndef SVD_H_INCLUDED
 2 #define SVD_H_INCLUDED
 3 
 4 #include <vector>
 5 #include <cstring>
 6 
 7 using namespace std;
 8 
 9 class SVD{
10     public:
11         SVD(double*,double*,int,double**,double**);
12         ~SVD();
13 
14         void loadTrainFile(string);
15         double predictScore(int,double,double,double*,double*);
16         double Validate(string,double,double*,double*,double**,double**);
17 //    private:
18         double* Bi;
19         double* Bu;
20         int factor;
21         double** Qi;
22         double** Pu;
23 };
24 
25 
26 #endif // SVD_H_INCLUDED

SVD.cpp

  1 #include <cmath>
  2 #include <iostream>
  3 #include <cstring>
  4 #include <cstdlib>
  5 #include <fstream>
  6 #include "SVD.h"
  7 
  8 
  9 int userNum = 6040;
 10 int itemNum = 3900;
 11 double AVG = 3.579231;
 12 double lr = 0.01;
 13 double theta = 0.05;
 14 double preRmse = 1000000.0;
 15 
 16 int main()
 17 {
 18     string trainFile = "/home/ja/CADATA/SVD/ml_data/training.txt";
 19     string testFile = "/home/ja/CADATA/SVD/ml_data/test.txt";
 20     srand(0);
 21     SVD svd(NULL,NULL,0,NULL,NULL);
 22 
 23     for(size_t i=0;i<100;i++){
 24         svd.loadTrainFile(trainFile);
 25         //lr *= 0.9;
 26         double curRmse = svd.Validate(testFile,AVG,svd.Bu,svd.Bi,svd.Pu,svd.Qi);
 27         cout << "test_Rmse in step " << i << ": " << curRmse << endl;
 28         if(curRmse >= preRmse){
 29             break;
 30         }
 31         else{
 32             preRmse = curRmse;
 33         }
 34     }
 35     return 0;
 36 }
 37 
 38 double SVD::Validate(string testfile,double avg,double* bu,double* bi,double** pu,double** qi){
 39     ifstream fin(testfile.c_str());
 40     if(!fin){
 41         cout << "error" << endl;
 42     }
 43     int userId,itemId,rating,t;
 44     int n = 0;
 45     double rmse;
 46     while(fin >> userId >> itemId >> rating >> t){
 47         n++;
 48         double pScore = predictScore(avg,bu[userId-1],bi[itemId-1],pu[userId-1],qi[itemId-1]);
 49         rmse += (rating - pScore) * (rating - pScore);
 50     }
 51     fin.close();
 52     return sqrt(rmse/n);
 53 }
 54 
 55 double SVD::predictScore(int avg,double bu,double bi,double* pu,double* qi){
 56     double tmp = 0.0;
 57     for(size_t i=0;i<factor;i++){
 58         tmp += pu[i] * qi[i];
 59     }
 60 
 61     double score = avg + bu + bi + tmp;
 62     if(score > 5){
 63         score = 5;
 64     }
 65     if(score < 1){
 66         score = 1;
 67     }
 68     return score;
 69 }
 70 
 71 void SVD::loadTrainFile(string file){
 72     ifstream fin(file.c_str());
 73     if(!fin){
 74         cout << "error" << endl;
 75     }
 76 
 77     int userId,itemId,rating,t;
 78     while(fin >> userId >> itemId >> rating >> t){
 79         double predict = predictScore(AVG,Bu[userId-1],Bi[itemId-1],Pu[userId-1],Qi[itemId-1]);
 80         double error = rating - predict;
 81         Bu[userId-1] += lr * (error - theta * Bu[userId-1]);
 82         Bi[itemId-1] += lr * (error - theta * Bi[itemId-1]);
 83 
 84         for(size_t i=0;i<factor;i++){
 85             double Tmp = Pu[userId-1][i];
 86             Pu[userId-1][i] += lr * (error * Qi[itemId-1][i] - theta * Pu[userId-1][i]);
 87             Qi[itemId-1][i] += lr * (error * Tmp - theta * Qi[itemId-1][i]);
 88         }
 89     }
 90     fin.close();
 91 }
 92 
 93 SVD::SVD(double* bi,double* bu,int k,double** qi,double** pu){
 94 
 95     if(bi == NULL){
 96         Bi = new double[itemNum];
 97         for(size_t i=0;i<itemNum;i++){
 98             Bi[i] = 0.0;
 99         }
100     }
101     else{
102         Bi = bi;
103     }
104 
105     if(bu == NULL){
106         Bu = new double[userNum];
107         for(size_t i=0;i<userNum;i++){
108             Bu[i] = 0.0;
109         }
110     }
111     else{
112         Bu = bu;
113     }
114 
115     factor = 10;
116 
117     if(qi == NULL){
118         Qi = new double* [itemNum];
119         for(size_t i=0;i<itemNum;i++){
120             Qi[i] = new double[factor];
121         }
122 
123         for(size_t i=0;i<itemNum;i++){
124             for(size_t j=0;j<factor;j++){
125                 Qi[i][j] = 0.1 * (rand() / (RAND_MAX + 1.0)) / sqrt(factor);
126             }
127         }
128     }
129     else{
130         Qi = qi;
131     }
132 
133     if(pu == NULL){
134         Pu = new double* [userNum];
135         for(size_t i=0;i<userNum;i++){
136             Pu[i] = new double[factor];
137         }
138 
139         for(size_t i=0;i<userNum;i++){
140             for(size_t j=0;j<factor;j++){
141                 Pu[i][j] = 0.1 * (rand() / (RAND_MAX + 1.0)) / sqrt(factor);
142             }
143         }
144     }
145     else{
146         Pu = pu;
147     }
148 }
149 
150 SVD::~SVD(){
151     delete[] Bi;
152     delete[] Bu;
153     for(size_t i=0;i<userNum;i++){
154         delete[] Pu[i];
155     }
156     for(size_t i=0;i<itemNum;i++){
157         delete[] Qi[i];
158     }
159     delete[] Pu;
160     delete[] Qi;
161 }

 

posted on 2014-03-22 16:58  Ja °  阅读(1188)  评论(0)  编辑  收藏  举报

导航