推荐系统偏好SVD - C++
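最近想研究一下推荐系统，其中比较经典的算法就是 SVD。具体理论这里不多讲，只简单说明：下面实现的是带偏置项的 SVD，预测评分为 $\hat r_{ui} = \mu + b_u + b_i + p_u^\top q_i$（结果截断到 $[1, 5]$）。其中 $\mu$ 为全局平均分，$b_u$、$b_i$ 为用户和物品偏置，$p_u$、$q_i$ 为隐因子向量，使用带 L2 正则的随机梯度下降训练。下面直接上代码。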
先贴一张效果图。实验数据规模：用户数 userNum = 6040，物品数 itemNum = 3900。
本文链接:http://www.cnblogs.com/wn19910213/p/3617781.html
代码如下：
SVD.h
#ifndef SVD_H_INCLUDED
#define SVD_H_INCLUDED

#include <vector>
#include <cstring>
#include <string>   // std::string is used in the member declarations below

// NOTE(review): a using-directive in a header leaks into every includer;
// kept because SVD.cpp relies on it for unqualified string/cout/ifstream.
using namespace std;

/// Biased matrix-factorization ("SVD") model for rating prediction.
///
/// Holds per-item biases (Bi), per-user biases (Bu) and the latent factor
/// matrices Qi (items) and Pu (users), each row `factor` doubles long.
class SVD{
public:
    /// @param bi  item-bias array, or NULL to allocate a zero-filled one
    /// @param bu  user-bias array, or NULL to allocate a zero-filled one
    /// @param k   latent-factor count (see the constructor in SVD.cpp)
    /// @param qi  item factor matrix, or NULL to allocate a random one
    /// @param pu  user factor matrix, or NULL to allocate a random one
    SVD(double*,double*,int,double**,double**);
    ~SVD();

    /// Runs one SGD pass over a ratings file (despite the name, it trains).
    void loadTrainFile(string);
    /// Predicts a rating clipped to [1, 5].
    /// NOTE(review): the first parameter (the global mean) is declared `int`,
    /// so a fractional mean passed here is truncated.
    double predictScore(int,double,double,double*,double*);
    /// Returns the RMSE of the given model arrays over a test ratings file.
    double Validate(string,double,double*,double*,double**,double**);

    // Data members are left public: the driver passes them to Validate().
    double* Bi;
    double* Bu;
    int factor;
    double** Qi;
    double** Pu;

private:
    // The destructor delete[]s the raw arrays above, so a member-wise copy
    // would free them twice. Declared but intentionally never defined
    // (pre-C++11 idiom for a non-copyable class).
    SVD(const SVD&);
    SVD& operator=(const SVD&);
};

#endif // SVD_H_INCLUDED
SVD.cpp
1 #include <cmath> 2 #include <iostream> 3 #include <cstring> 4 #include <cstdlib> 5 #include <fstream> 6 #include "SVD.h" 7 8 9 int userNum = 6040; 10 int itemNum = 3900; 11 double AVG = 3.579231; 12 double lr = 0.01; 13 double theta = 0.05; 14 double preRmse = 1000000.0; 15 16 int main() 17 { 18 string trainFile = "/home/ja/CADATA/SVD/ml_data/training.txt"; 19 string testFile = "/home/ja/CADATA/SVD/ml_data/test.txt"; 20 srand(0); 21 SVD svd(NULL,NULL,0,NULL,NULL); 22 23 for(size_t i=0;i<100;i++){ 24 svd.loadTrainFile(trainFile); 25 //lr *= 0.9; 26 double curRmse = svd.Validate(testFile,AVG,svd.Bu,svd.Bi,svd.Pu,svd.Qi); 27 cout << "test_Rmse in step " << i << ": " << curRmse << endl; 28 if(curRmse >= preRmse){ 29 break; 30 } 31 else{ 32 preRmse = curRmse; 33 } 34 } 35 return 0; 36 } 37 38 double SVD::Validate(string testfile,double avg,double* bu,double* bi,double** pu,double** qi){ 39 ifstream fin(testfile.c_str()); 40 if(!fin){ 41 cout << "error" << endl; 42 } 43 int userId,itemId,rating,t; 44 int n = 0; 45 double rmse; 46 while(fin >> userId >> itemId >> rating >> t){ 47 n++; 48 double pScore = predictScore(avg,bu[userId-1],bi[itemId-1],pu[userId-1],qi[itemId-1]); 49 rmse += (rating - pScore) * (rating - pScore); 50 } 51 fin.close(); 52 return sqrt(rmse/n); 53 } 54 55 double SVD::predictScore(int avg,double bu,double bi,double* pu,double* qi){ 56 double tmp = 0.0; 57 for(size_t i=0;i<factor;i++){ 58 tmp += pu[i] * qi[i]; 59 } 60 61 double score = avg + bu + bi + tmp; 62 if(score > 5){ 63 score = 5; 64 } 65 if(score < 1){ 66 score = 1; 67 } 68 return score; 69 } 70 71 void SVD::loadTrainFile(string file){ 72 ifstream fin(file.c_str()); 73 if(!fin){ 74 cout << "error" << endl; 75 } 76 77 int userId,itemId,rating,t; 78 while(fin >> userId >> itemId >> rating >> t){ 79 double predict = predictScore(AVG,Bu[userId-1],Bi[itemId-1],Pu[userId-1],Qi[itemId-1]); 80 double error = rating - predict; 81 Bu[userId-1] += lr * (error - theta * Bu[userId-1]); 82 Bi[itemId-1] += 
lr * (error - theta * Bi[itemId-1]); 83 84 for(size_t i=0;i<factor;i++){ 85 double Tmp = Pu[userId-1][i]; 86 Pu[userId-1][i] += lr * (error * Qi[itemId-1][i] - theta * Pu[userId-1][i]); 87 Qi[itemId-1][i] += lr * (error * Tmp - theta * Qi[itemId-1][i]); 88 } 89 } 90 fin.close(); 91 } 92 93 SVD::SVD(double* bi,double* bu,int k,double** qi,double** pu){ 94 95 if(bi == NULL){ 96 Bi = new double[itemNum]; 97 for(size_t i=0;i<itemNum;i++){ 98 Bi[i] = 0.0; 99 } 100 } 101 else{ 102 Bi = bi; 103 } 104 105 if(bu == NULL){ 106 Bu = new double[userNum]; 107 for(size_t i=0;i<userNum;i++){ 108 Bu[i] = 0.0; 109 } 110 } 111 else{ 112 Bu = bu; 113 } 114 115 factor = 10; 116 117 if(qi == NULL){ 118 Qi = new double* [itemNum]; 119 for(size_t i=0;i<itemNum;i++){ 120 Qi[i] = new double[factor]; 121 } 122 123 for(size_t i=0;i<itemNum;i++){ 124 for(size_t j=0;j<factor;j++){ 125 Qi[i][j] = 0.1 * (rand() / (RAND_MAX + 1.0)) / sqrt(factor); 126 } 127 } 128 } 129 else{ 130 Qi = qi; 131 } 132 133 if(pu == NULL){ 134 Pu = new double* [userNum]; 135 for(size_t i=0;i<userNum;i++){ 136 Pu[i] = new double[factor]; 137 } 138 139 for(size_t i=0;i<userNum;i++){ 140 for(size_t j=0;j<factor;j++){ 141 Pu[i][j] = 0.1 * (rand() / (RAND_MAX + 1.0)) / sqrt(factor); 142 } 143 } 144 } 145 else{ 146 Pu = pu; 147 } 148 } 149 150 SVD::~SVD(){ 151 delete[] Bi; 152 delete[] Bu; 153 for(size_t i=0;i<userNum;i++){ 154 delete[] Pu[i]; 155 } 156 for(size_t i=0;i<itemNum;i++){ 157 delete[] Qi[i]; 158 } 159 delete[] Pu; 160 delete[] Qi; 161 }