apriori算法

Apriori算法简单实现#

前言#

以如下数据为例,来说明算法的运行过程,找出其频繁项。数据中每一行代表一条数据,每一列可以代表待关联的事物,比如每个客户购买的每个商品

Copy
[['a','c','e'], ['b','d'], ['b','c'], ['a','b','c','d'], ['a','b'], ['b','c'], ['a','b'], ['a','b','c','e'], ['a','b','c'], ['a','c','e']]

为了便于处理,将数据转化为如下格式,其中每个元素为1,表示用户购买了该商品,为0表示没有购买

Copy
a b c d e 0 1 0 1 0 1 1 0 1 0 1 0 2 0 1 1 0 0 3 1 1 1 1 0 4 1 1 0 0 0 5 0 1 1 0 0 6 1 1 0 0 0 7 1 1 1 0 1 8 1 1 1 0 0 9 1 0 1 0 1

最后,设置阈值为3。

第一轮处理#

首先得到个数为1的项集,然后统计数据中每一列值为1的个数(其实就是统计每一个项集的个数),最后将个数小于阈值的项集去除,并保存在stop表中。这样得到第一轮处理后的频繁项集。

开始循环处理#

循环结束的条件是,频繁项集的个数为0。
首先,利用上一步得到的频繁项集进行连接(即求各个集合的并),然后去除重复的项集和项集中元素错误的(如第二次处理时,每个项集元素个数应该为2,如果出现个数为1,为3的,就要删除)。然后进行剪枝操作(即去除子集在stop表中的项集),最后统计项集在原数据中的个数,最后将个数小于阈值的项集去除,并保存在stop表中。这样得到第二轮处理后的频繁项集。
统计时采用的策略如下,以统计ac的次数为例,只要找到a,c对应的列,将每一行相加,然后统计其和为2的个数,就是ac出现的次数:

之后进行循环,直至满足循环结束条件。

以上述数据为例的算法过程可视化如下:#

程序运行流程及结果如下:#

  • 去掉['d']
  • 频繁项[(['a', 'b'], 5), (['a', 'c'], 5), (['a', 'e'], 3), (['c', 'b'], 5), (['e', 'c'], 3)]
  • 去掉[['e', 'b']]
  • 频繁项[(['a', 'c', 'b'], 3), (['a', 'c', 'e'], 3)]
  • 去掉[]
  • 频繁项[]
  • 去掉[['a', 'b', 'e', 'c']]
  • 结果:[(['a', 'b'], 5), (['c', 'a'], 5), (['a', 'e'], 3), (['c', 'b'], 5), (['c', 'e'], 3), (['c', 'a', 'b'], 3), (['c', 'a', 'e'], 3)]

运行可视化如下:#

代码#

Copy
from copy import copy import pandas as pd def Apriori(data,th): #保存最终频繁项 result=[] #得到要统计的每一项 col=data.columns.tolist() #统计每一项个数 tmp=data[data==1].count().tolist() #根据每一项统计的数目,删除比阈值小的 stop=[] f=list(zip(col,tmp)) f_copy = copy(f) for i, j in f_copy: if j<th: stop.append(i) f.remove((i,j)) # result.extend(f) # print('频繁项') # print(f) print('去掉') print(stop) turn=1 while True: turn=turn+1 # 得到要统计的每一项 tmp = [] col=[] d=[i[0] for i in f] col_set = [] #连接操作 for i in d: for j in d[d.index(i) + 1:]: #取并集 item=set(i).union(set(j)) #删除连接后不符合要求的元素 #删除个数不对的,删除重复的 if (len(item)==turn) and (item not in col_set) : #剪枝操作,删除子集就不是频繁项的 if (len(stop)!=0): judge=[set(n).issubset(item) for n in stop] if (True not in judge): col_set.append(item) else: col_set.append(item) col=[list(i) for i in col_set] #直到没有频繁项就跳出循环 if len(col)==0: break # 统计每一项个数 for i in col: tmp.append(data[data[i].sum(axis=1) == turn].count().tolist()[0]) # 根据每一项统计的数目,删除比阈值小的 stop = [] f = list(zip(col, tmp)) f_copy=copy(f) for i, j in f_copy: if j < th: stop.append(i) f.remove((i, j)) result.extend(f) print('频繁项') print(f) print('去掉') print(stop) print('结果') print(result) data=pd.DataFrame({'a':[1,0,0,1,1,0,1,1,1,1],'b':[0,1,1,1,1,1,1,1,1,0],'c':[1,0,1,1,0,1,0,1,1,1],'d':[0,1,0,1,0,0,0,0,0,0],'e':[1,0,0,0,0,0,0,1,0,1]}) print(data) Apriori(data,3)
posted @   启林O_o  阅读(78)  评论(0编辑  收藏  举报
编辑推荐:
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
· [.NET]调用本地 Deepseek 模型
· 一个费力不讨好的项目,让我损失了近一半的绩效!
阅读排行:
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性
点击右上角即可分享
微信分享提示
CONTENTS