apriori算法
Apriori算法简单实现#
前言#
以如下数据为例,来说明算法的运行过程,找出其频繁项。数据中每一行代表一条数据,每一列可以代表待关联的事物,比如每个客户购买的每个商品
[['a','c','e'],
['b','d'],
['b','c'],
['a','b','c','d'],
['a','b'],
['b','c'],
['a','b'],
['a','b','c','e'],
['a','b','c'],
['a','c','e']]
为了便于处理,将数据转化为如下格式,其中每个元素为1,表示用户购买了该商品,为0表示没有购买
a b c d e
0 1 0 1 0 1
1 0 1 0 1 0
2 0 1 1 0 0
3 1 1 1 1 0
4 1 1 0 0 0
5 0 1 1 0 0
6 1 1 0 0 0
7 1 1 1 0 1
8 1 1 1 0 0
9 1 0 1 0 1
最后,设置阈值为3。
第一轮处理#
首先得到个数为1的项集,然后统计数据中每一列值为1的个数(其实就是统计每一个项集的个数),最后将个数小于阈值的项集去除,并保存在stop表中。这样得到第一轮处理后的频繁项集。
开始循环处理#
循环结束的条件是,频繁项集的个数为0。
首先,利用上一步得到的频繁项集进行连接(即求各个集合的并),然后去除重复的项集和项集中元素错误的(如第二次处理时,每个项集元素个数应该为2,如果出现个数为1,为3的,就要删除)。然后进行剪枝操作(即去除子集在stop表中的项集),最后统计项集在原数据中的个数,最后将个数小于阈值的项集去除,并保存在stop表中。这样得到第二轮处理后的频繁项集。
统计时采用的策略如下,以统计ac的次数为例,只要找到a,c对应的列,将每一行相加,然后统计其和为2的个数,就是ac出现的次数:
之后进行循环,直至满足循环结束条件。
以上述数据为例的算法过程可视化如下:#
程序运行流程及结果如下:#
- 去掉['d']
- 频繁项[(['a', 'b'], 5), (['a', 'c'], 5), (['a', 'e'], 3), (['c', 'b'], 5), (['e', 'c'], 3)]
- 去掉[['e', 'b']]
- 频繁项[(['a', 'c', 'b'], 3), (['a', 'c', 'e'], 3)]
- 去掉[]
- 频繁项[]
- 去掉[['a', 'b', 'e', 'c']]
- 结果:[(['a', 'b'], 5), (['c', 'a'], 5), (['a', 'e'], 3), (['c', 'b'], 5), (['c', 'e'], 3), (['c', 'a', 'b'], 3), (['c', 'a', 'e'], 3)]
运行可视化如下:#
代码#
from copy import copy
import pandas as pd
def Apriori(data,th):
#保存最终频繁项
result=[]
#得到要统计的每一项
col=data.columns.tolist()
#统计每一项个数
tmp=data[data==1].count().tolist()
#根据每一项统计的数目,删除比阈值小的
stop=[]
f=list(zip(col,tmp))
f_copy = copy(f)
for i, j in f_copy:
if j<th:
stop.append(i)
f.remove((i,j))
# result.extend(f)
# print('频繁项')
# print(f)
print('去掉')
print(stop)
turn=1
while True:
turn=turn+1
# 得到要统计的每一项
tmp = []
col=[]
d=[i[0] for i in f]
col_set = []
#连接操作
for i in d:
for j in d[d.index(i) + 1:]:
#取并集
item=set(i).union(set(j))
#删除连接后不符合要求的元素
#删除个数不对的,删除重复的
if (len(item)==turn) and (item not in col_set) :
#剪枝操作,删除子集就不是频繁项的
if (len(stop)!=0):
judge=[set(n).issubset(item) for n in stop]
if (True not in judge):
col_set.append(item)
else:
col_set.append(item)
col=[list(i) for i in col_set]
#直到没有频繁项就跳出循环
if len(col)==0:
break
# 统计每一项个数
for i in col:
tmp.append(data[data[i].sum(axis=1) == turn].count().tolist()[0])
# 根据每一项统计的数目,删除比阈值小的
stop = []
f = list(zip(col, tmp))
f_copy=copy(f)
for i, j in f_copy:
if j < th:
stop.append(i)
f.remove((i, j))
result.extend(f)
print('频繁项')
print(f)
print('去掉')
print(stop)
print('结果')
print(result)
data=pd.DataFrame({'a':[1,0,0,1,1,0,1,1,1,1],'b':[0,1,1,1,1,1,1,1,1,0],'c':[1,0,1,1,0,1,0,1,1,1],'d':[0,1,0,1,0,0,0,0,0,0],'e':[1,0,0,0,0,0,0,1,0,1]})
print(data)
Apriori(data,3)
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
· [.NET]调用本地 Deepseek 模型
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性