程序简介
项目以ml-100k电影评分数据集为输入,实现了基于用户的协同过滤算法,最后预测的MAE为0.84,因为经过优化,10万条评分数据运行时间不超过2分钟
协同过滤算法(CF)基于对用户历史行为数据的挖掘发现用户的喜好偏向,并预测用户可能喜好的产品进行推荐。也就是常见的“猜你喜欢”,和“购买了该商品的人也喜欢”等功能。
程序/数据集下载
代码分析
导入模块、路径、预设参数
# -*- coding: utf-8 -*-
import numpy as np
from numba import jit
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error as MAE
import os
######################参数#################################
topK = 20#近邻用户数
testRate = 0.2#测试比例
seed = 2#随机种子
testCount = 1000#最终参与测试行为数(不然运行太慢)
##########################################################
#用来正常显示中文标签
plt.rcParams['font.sans-serif']=['SimHei']
#用来正常显示负号
plt.rcParams['axes.unicode_minus']=False
#路径目录
baseDir = ''#当前目录
staticDir = os.path.join(baseDir,'Static')#静态文件目录
resultDir = os.path.join(baseDir,'Result')#结果文件目录
载入数据,设置索引(这样会快很多很多),查看数据前5行
#读数据
users = pd.read_table(staticDir+'/电影评分数据/u.user',sep="|",names=['user_id','age','sex','occupation','zip_code'],encoding='latin-1',engine='python')
ratings = pd.read_table(staticDir+'/电影评分数据/u.data', sep='\t', names=['user_id', 'movie_id', 'rating', 'unix_timestamp'],encoding='latin-1',engine='python')
movies = pd.read_table(staticDir+'/电影评分数据/u.item',engine='python', sep='|',header=None,encoding='latin-1',names=['movie_id','title','release_date','video_release_date','IMDb_URL','unknown','Action','Adventure','Animation','Children','Comedy','Crime','Documentary','Drama','Fantasy','Film-Noir','Horror','Musical','Mystery','Romance','Sci-Fi','Thriller','War','Western'])
movies = movies.iloc[:,:5]
#设置索引
users = users.set_index(['user_id'],drop=False)
movies = movies.set_index(['movie_id'],drop=False)
ratings = ratings.set_index(['user_id','movie_id'],drop=False)
用户数据
users.head()
user_id | age | sex | occupation | zip_code | |
---|---|---|---|---|---|
user_id | |||||
1 | 1 | 24 | M | technician | 85711 |
2 | 2 | 53 | F | other | 94043 |
3 | 3 | 23 | M | writer | 32067 |
4 | 4 | 24 | M | technician | 43537 |
5 | 5 | 33 | F | other | 15213 |
电影数据
movies.head()
movie_id | title | release_date | video_release_date | IMDb_URL | |
---|---|---|---|---|---|
movie_id | |||||
1 | 1 | Toy Story (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Toy%20Story%2... |
2 | 2 | GoldenEye (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?GoldenEye%20(... |
3 | 3 | Four Rooms (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Four%20Rooms%... |
4 | 4 | Get Shorty (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Get%20Shorty%... |
5 | 5 | Copycat (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Copycat%20(1995) |
评分数据(行为数据)
ratings.head()
user_id | movie_id | rating | unix_timestamp | ||
---|---|---|---|---|---|
user_id | movie_id | ||||
196 | 242 | 196 | 242 | 3 | 881250949 |
186 | 302 | 186 | 302 | 3 | 891717742 |
22 | 377 | 22 | 377 | 1 | 878887116 |
244 | 51 | 244 | 51 | 2 | 880606923 |
166 | 346 | 166 | 346 | 1 | 886397596 |
分割评分数据集为测试集和训练集
#分割行为数据成测试集和训练集
np.random.seed(seed)
testIndex = np.random.choice(range(ratings.shape[0]),size=int(ratings.shape[0]*testRate),replace=False)
testRatings = ratings.iloc[testIndex,:]
trainIndex = list(set(range(ratings.shape[0]))-set(testIndex))
trainRatings = ratings.iloc[trainIndex,:]
电影数据新增users列,统计出每部电影被哪些用户观看过,查看前5行
#统计所有电影在训练集被观看的用户id
def calUsers(movieId):
#观影记录
try:
views = trainRatings.loc[(slice(None),movieId),:]
except:
return []
users = views['user_id'].values.tolist()
return users
movies['users'] = movies['movie_id'].apply(calUsers)
movies.head()
movie_id | title | release_date | video_release_date | IMDb_URL | users | |
---|---|---|---|---|---|---|
movie_id | ||||||
1 | 1 | Toy Story (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Toy%20Story%2... | [308, 287, 148, 280, 66, 109, 181, 95, 189, 14... |
2 | 2 | GoldenEye (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?GoldenEye%20(... | [5, 268, 276, 87, 250, 201, 64, 13, 213, 373, ... |
3 | 3 | Four Rooms (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Four%20Rooms%... | [181, 81, 130, 49, 320, 145, 95, 99, 267, 417,... |
4 | 4 | Get Shorty (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Get%20Shorty%... | [99, 19, 207, 295, 201, 10, 308, 328, 109, 334... |
5 | 5 | Copycat (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Copycat%20(1995) | [293, 43, 311, 109, 344, 145, 314, 308, 280, 4... |
用户数据新增movies列,统计所有用户在训练集看过的电影id,查看前5行
#统计所有用户在训练集看过的电影id
def calMovies(userId):
#观影记录
try:
views = trainRatings.loc[(userId,slice(None)),:]
except:
return []
movies = views['movie_id'].values.tolist()
return movies
users['movies'] = users['user_id'].apply(calMovies)
users.head()
user_id | age | sex | occupation | zip_code | movies | |
---|---|---|---|---|---|---|
user_id | ||||||
1 | 1 | 24 | M | technician | 85711 | [61, 189, 33, 160, 20, 202, 171, 265, 117, 47,... |
2 | 2 | 53 | F | other | 94043 | [292, 251, 314, 297, 312, 281, 13, 303, 308, 2... |
3 | 3 | 23 | M | writer | 32067 | [335, 245, 337, 343, 323, 331, 294, 332, 334, ... |
4 | 4 | 24 | M | technician | 43537 | [264, 303, 361, 357, 260, 356, 294, 288, 50, 2... |
5 | 5 | 33 | F | other | 15213 | [2, 439, 225, 110, 454, 424, 363, 98, 102, 211... |
计算相似度矩阵,查看前5行5列,元素即每个用户之间相互的相似度,计算公式如下
#计算相似度矩阵
sims = pd.DataFrame(0,columns=users.index,index=users.index)
def calSim(userId1,userId2):
#用户1的物品
user1Items = users.loc[userId1,'movies']
#物品2的用户
user2Items = users.loc[userId2,'movies']
#两个物品共同用户
cross = list(set(user1Items) & set(user2Items))
#相似度
sim = len(cross)/((max(1e-1,len(user1Items))*max(1e-1,len(user2Items)))**0.5)
return sim
def fillSims(row):
userIds = pd.Series(row.index)
row[:] = userIds.apply(calSim,args=(row.name,))
return row
sims = sims.apply(fillSims,axis=1)
sims.iloc[:5,:5]
user_id | 1 | 2 | 3 | 4 | 5 |
---|---|---|---|---|---|
user_id | |||||
1 | 1.000000 | 0.116008 | 0.069067 | 0.064449 | 0.275500 |
2 | 0.116008 | 1.000000 | 0.142887 | 0.133333 | 0.036380 |
3 | 0.069067 | 0.142887 | 1.000000 | 0.238145 | 0.012377 |
4 | 0.064449 | 0.133333 | 0.238145 | 1.000000 | 0.040423 |
5 | 0.275500 | 0.036380 | 0.012377 | 0.040423 | 1.000000 |
用户数据新增nears列,统计用户的邻近用户,查看前5行
#计算用户的邻近用户
def calNearUsers(userId):
#该物品的相似度向量,选取topK个物品
nearUserIds = sims.loc[:,userId].sort_values(ascending=False)[1:topK+1]
nearUserIds = nearUserIds.index.tolist()
return nearUserIds
users['near'] = users['user_id'].apply(calNearUsers)
users.head()
user_id | age | sex | occupation | zip_code | movies | near | |
---|---|---|---|---|---|---|---|
user_id | |||||||
1 | 1 | 24 | M | technician | 85711 | [61, 189, 33, 160, 20, 202, 171, 265, 117, 47,... | [457, 435, 916, 648, 933, 276, 864, 297, 805, ... |
2 | 2 | 53 | F | other | 94043 | [292, 251, 314, 297, 312, 281, 13, 303, 308, 2... | [701, 673, 926, 131, 306, 569, 937, 520, 486, ... |
3 | 3 | 23 | M | writer | 32067 | [335, 245, 337, 343, 323, 331, 294, 332, 334, ... | [752, 489, 784, 587, 863, 529, 783, 428, 126, ... |
4 | 4 | 24 | M | technician | 43537 | [264, 303, 361, 357, 260, 356, 294, 288, 50, 2... | [33, 816, 750, 408, 443, 783, 725, 596, 355, 6... |
5 | 5 | 33 | F | other | 15213 | [2, 439, 225, 110, 454, 424, 363, 98, 102, 211... | [222, 648, 407, 56, 495, 254, 497, 457, 727, 1... |
对测试集评分进行预测,查看前5行,最终测试MAE为0.84,公式如下
def predict(row):
'''预测评分'''
userId = row['user_id']
movieId = row['movie_id']
#topK个邻近用户
nearUserIds = users.loc[userId,'near']
#用户在训练集的物品
itemUserIds = movies.loc[movieId,'users']
#物品交集
cross = list(set(nearUserIds) & set(itemUserIds))
#预测评分
up = 0#分母
down = 0#分子
for nearUserId in cross:
sim = sims.loc[nearUserId,userId]
down += sim
#用户对邻近物品评分
score = trainRatings.loc[(nearUserId,movieId),'rating']
up += score * sim
if up == 0:
return None
else:
return up/down
#开始测试
testRatings['predict'] = testRatings.apply(predict,axis=1)
testRatings = testRatings.dropna()
mae = MAE(testRatings['rating'],testRatings['predict'])
print('测试集MAE为%.2f'%mae)
testRatings.head()
测试集MAE为0.84
user_id | movie_id | rating | unix_timestamp | predict | ||
---|---|---|---|---|---|---|
user_id | movie_id | |||||
157 | 273 | 157 | 273 | 5 | 886889876 | 3.740361 |
405 | 1065 | 405 | 1065 | 1 | 885546069 | 3.790151 |
244 | 550 | 244 | 550 | 1 | 880602264 | 3.164083 |
378 | 768 | 378 | 768 | 4 | 880333598 | 2.786181 |
919 | 111 | 919 | 111 | 4 | 875288681 | 3.691650 |