Apriori算法
给出项集支持度的定义:数据集中包含该项集的数据的比例
规则\(A\rightarrow B\)的置信度的定义:即\(P(B\mid A) = \frac{P(AB)}{P(A)}\),其中\(P(x)\)为项集\(x\)的支持度,\(P(AB)\)为项集\(A\cup B\)的支持度
算法流程:
先算出所有满足最小支持度的频繁项集,这个可以迭代来算,具体做法就是:
- 先找出所有的大小为\(1\)的候选项集列表\(C_1\)(\(C_1\)即所有不同元素构成的列表)
- 在候选项集列表\(C_k\)中找出满足最小支持度的项集,加入到元素数量为\(k\)的频繁项集列表\(L_k\)中
- 根据\(L_k\)合并出下一个候选项集列表\(C_{k+1}\)
- 回到步骤\(2\),直到无法再生成新的候选项集为止
合并出下一个候选项集列表操作的算法原理就是:如果某个项集不是频繁项集,那么这个项集的所有超集都不是频繁项集。这样做剪枝可以去掉很多没用的状态
然后根据这些频繁项集来计算满足最小置信度的规则
这里的做法是:
先枚举每一个频繁项集
对于一个频繁项集会有很多可能满足最小置信度的规则,具体来说,如果一个频繁项集存在\(n\)个元素,那么存在\(2^n-2\)个规则(所有的组合再去掉全集和空集)
但是可以发现如果一个规则\(A\rightarrow B\)没有达到最小置信度,那么规则\((A-sub)\rightarrow (B|sub)\)也达不到最小置信度(其中\(sub\)为\(A\)的非空真子集,\(|\)表示集合并,\(-\)表示集合差)。这是因为两条规则的分子都是同一个频繁项集的支持度,而前件变小后其支持度只会变大,即\(P(A-sub)\ge P(A)\),所以置信度只会变小。也就是说,以\(A\)的任何非空真子集作为前件的规则都不满足条件了
那么就可以和找频繁项集一样剪枝掉很多状态了
对于每一个频繁项集,先把所有单个的元素作为规则后件集合形成一个规则后件集合\(list\),然后计算每个规则是否满足条件
把满足条件的后件集合保留在\(list\)中,和找频繁项集中的操作\(3\)一样把它们两两合并成更大的候选后件集合,再通过集合相减得到对应的前件集合,然后递归地重复上一步操作就好了
下面是代码
view code
#coding:utf-8
# generate data
def genData():
    """Return a small hard-coded shopping-basket dataset.

    Returns:
        A new list of transactions; each transaction is a list of product
        names (strings).
    """
    transactions = [
        ['牛奶', '啤酒', '尿布'],
        ['牛奶', '面包', '黄油'],
        ['牛奶', '尿布', '饼干'],
        ['面包', '黄油', '饼干'],
        ['啤酒', '尿布', '饼干'],
        ['牛奶', '尿布', '面包', '黄油'],
        ['尿布', '面包', '黄油'],
        ['啤酒', '尿布'],
        ['牛奶', '尿布', '面包', '黄油'],
        ['啤酒', '饼干'],
    ]
    return transactions
def loadDataSet():
    """Return a tiny integer-coded transaction dataset for quick experiments.

    Returns:
        A new list of four transactions, each a list of integer item ids.
    """
    transactions = [
        [1, 3, 4],
        [2, 3, 5],
        [1, 2, 3, 5],
        [2, 5],
    ]
    return transactions
def genC1(datalist)->[frozenset]:
    """Build the size-1 candidate itemsets C1.

    Args:
        datalist: iterable of transactions, each an iterable of hashable
            items.

    Returns:
        A list holding one single-element frozenset per distinct item. The
        list order follows set iteration order and is not otherwise defined.
    """
    # Collect every distinct item across all transactions.
    distinct_items = {goods for items in datalist for goods in items}
    return [frozenset([goods]) for goods in distinct_items]
def mergeToNext(preL):
    """Join frequent k-itemsets into the next list of candidate itemsets.

    Two itemsets are joined (set union) when their first k-1 items, taken in
    sorted order, are equal. k is read from the first itemset; all itemsets
    in ``preL`` are assumed to have that same size.

    Args:
        preL: list of frozensets (the frequent itemsets of one size). Items
            must be mutually orderable because they are sorted.

    Returns:
        A list of frozensets, one per joined pair, in pair order (i < j).
        Empty when fewer than two itemsets are given.
    """
    # With fewer than two itemsets there is no pair to join.
    if len(preL) < 2:
        return []
    k = len(preL[0])
    # Sort each itemset once up front instead of re-sorting it for every
    # pair inside the double loop.
    prefixes = [sorted(itemset)[:k - 1] for itemset in preL]
    Ck = []
    for i, left in enumerate(preL):
        for j in range(i + 1, len(preL)):
            if prefixes[i] == prefixes[j]:
                Ck.append(left | preL[j])
    return Ck
def genfreq(dataset, preC, minsupport):
    """Filter candidate itemsets by minimum support.

    The support of a candidate is the fraction of transactions in
    ``dataset`` that contain every item of the candidate.

    Args:
        dataset: sized collection of transactions (each an iterable of
            items; frozensets when called from ``apriori``).
        preC: iterable of candidate itemsets (frozensets).
        minsupport: minimum support, as a fraction in [0, 1].

    Returns:
        A tuple ``(L, objfreq)``: ``L`` is the list of candidates whose
        support is at least ``minsupport`` (in candidate order) and
        ``objfreq`` maps each of those candidates to its support.
    """
    total = len(dataset)
    if total == 0:
        # Support is undefined without transactions; report nothing as
        # frequent instead of dividing by zero.
        return [], {}
    L = []
    objfreq = {}
    for item in preC:
        hits = sum(1 for data in dataset if item.issubset(data))
        support = hits / total
        if support >= minsupport:
            L.append(item)
            objfreq[item] = support
    return L, objfreq
def GetRules(freqset, R, suppotdata, rulelist, minconf):
    """Append every rule derived from ``freqset`` that reaches ``minconf``.

    For each candidate consequent ``B`` in ``R`` the rule
    ``(freqset - B) -> B`` is scored with
    ``confidence = support(freqset) / support(freqset - B)``.
    Consequents whose rule passes are joined with ``mergeToNext`` into
    candidates one item larger and scored recursively. Joining only the
    passing consequents is the pruning step: enlarging a consequent
    shrinks the antecedent, which can only raise the denominator and so
    lower the confidence.

    Args:
        freqset: the frequent itemset (frozenset) to split into rules.
        R: list of candidate consequents, all of the same size, each a
            subset of ``freqset``.
        suppotdata: mapping from itemset to support. It must contain
            ``freqset`` and every antecedent that gets scored. The
            misspelled name is kept so existing callers are unaffected.
        rulelist: output list, mutated in place; each accepted rule is
            appended as ``[antecedent, consequent, confidence]``.
        minconf: minimum confidence a rule must reach.

    Returns:
        None.
    """
    # Read supports from the argument. The body used to read a
    # module-level name ``supportdata`` and failed without that global.
    supportdata = suppotdata
    # Stop when there is nothing to score or the consequent would be the
    # whole itemset (which leaves an empty antecedent).
    if len(R) == 0 or len(R[0]) >= len(freqset):
        return
    kept = []
    for consequent in R:
        antecedent = freqset - consequent
        # conf(A -> B) = support(A | B) / support(A)
        conf = supportdata[freqset] / supportdata[antecedent]
        if conf >= minconf:
            rulelist.append([antecedent, consequent, conf])
            kept.append(consequent)
    # Joining needs at least two surviving consequents.
    if len(kept) < 2:
        return
    GetRules(freqset, mergeToNext(kept), suppotdata, rulelist, minconf)
def genRules(Llist, supportdata, minconf = .5):
    """Generate association rules from the frequent itemsets.

    Args:
        Llist: list of lists of frequent itemsets as returned by
            ``apriori``; the first entry is skipped, and processing stops
            at the first empty entry after it.
        supportdata: mapping from itemset to support, passed through to
            ``GetRules``.
        minconf: minimum confidence, passed through to ``GetRules``.

    Returns:
        The list that ``GetRules`` filled; each entry is
        ``[antecedent, consequent, confidence]``.
    """
    rulelist = []
    for L in Llist[1:]:
        if not L:
            break
        for freqset in L:
            # Start from every single item as a candidate consequent.
            single_item_consequents = [frozenset([x]) for x in freqset]
            GetRules(freqset, single_item_consequents, supportdata, rulelist, minconf)
    return rulelist
def apriori(datalist, minsupport = .5):
    """Find all frequent itemsets level by level (C1 -> L1 -> C2 -> ...).

    Args:
        datalist: iterable of transactions, each an iterable of hashable
            items.
        minsupport: minimum support, passed through to ``genfreq``.

    Returns:
        A tuple ``(Llist, supportdata)``: ``Llist`` holds one list of
        frequent itemsets per level (the last one may be empty) and
        ``supportdata`` maps every frequent itemset to its support.
    """
    dataset = [frozenset(transaction) for transaction in datalist]
    Llist = []
    supportdata = {}
    candidates = genC1(dataset)
    while candidates:
        frequent, supports = genfreq(dataset, candidates, minsupport)
        Llist.append(frequent)
        supportdata.update(supports)
        # Join this level's frequent itemsets into the next candidates.
        candidates = mergeToNext(frequent)
    return Llist, supportdata
if __name__ == "__main__":
    # Demo on the small integer dataset; assign genData() to datalist
    # instead to run on the shopping-basket sample.
    datalist = loadDataSet()
    Llist, supportdata = apriori(datalist)
    rulelist = genRules(Llist, supportdata)
    # To inspect supports, print supportdata[p] for each p in each L of Llist.
    # Each rule is [antecedent, consequent, confidence].
    for antecedent, consequent, conf in rulelist:
        print(antecedent, '->', consequent, 'conf = ', conf)
处理Online_Retail数据的代码
只处理\(France\)部分
view code
import pandas as pd
# !/usr/bin/python
# coding:utf-8
# author: kiko
def genC1(datalist)->[frozenset]:
    """Build the size-1 candidate itemsets C1.

    Args:
        datalist: iterable of transactions, each an iterable of hashable
            items.

    Returns:
        A list holding one single-element frozenset per distinct item. The
        list order follows set iteration order and is not otherwise defined.
    """
    # Collect every distinct item across all transactions.
    distinct_items = {goods for items in datalist for goods in items}
    return [frozenset([goods]) for goods in distinct_items]
def mergeToNext(preL):
    """Join frequent k-itemsets into the next list of candidate itemsets.

    Two itemsets are joined (set union) when their first k-1 items, taken in
    sorted order, are equal. k is read from the first itemset; all itemsets
    in ``preL`` are assumed to have that same size.

    Args:
        preL: list of frozensets (the frequent itemsets of one size). Items
            must be mutually orderable because they are sorted.

    Returns:
        A list of frozensets, one per joined pair, in pair order (i < j).
        Empty when fewer than two itemsets are given.
    """
    # With fewer than two itemsets there is no pair to join.
    if len(preL) < 2:
        return []
    k = len(preL[0])
    # Sort each itemset once up front instead of re-sorting it for every
    # pair inside the double loop.
    prefixes = [sorted(itemset)[:k - 1] for itemset in preL]
    Ck = []
    for i, left in enumerate(preL):
        for j in range(i + 1, len(preL)):
            if prefixes[i] == prefixes[j]:
                Ck.append(left | preL[j])
    return Ck
def genfreq(dataset, preC, minsupport):
    """Filter candidate itemsets by minimum support.

    The support of a candidate is the fraction of transactions in
    ``dataset`` that contain every item of the candidate.

    Args:
        dataset: sized collection of transactions (each an iterable of
            items; frozensets when called from ``apriori``).
        preC: iterable of candidate itemsets (frozensets).
        minsupport: minimum support, as a fraction in [0, 1].

    Returns:
        A tuple ``(L, objfreq)``: ``L`` is the list of candidates whose
        support is at least ``minsupport`` (in candidate order) and
        ``objfreq`` maps each of those candidates to its support.
    """
    total = len(dataset)
    if total == 0:
        # Support is undefined without transactions; report nothing as
        # frequent instead of dividing by zero.
        return [], {}
    L = []
    objfreq = {}
    for item in preC:
        hits = sum(1 for data in dataset if item.issubset(data))
        support = hits / total
        if support >= minsupport:
            L.append(item)
            objfreq[item] = support
    return L, objfreq
def GetRules(freqset, R, suppotdata, rulelist, minconf):
    """Append every rule derived from ``freqset`` that reaches ``minconf``.

    For each candidate consequent ``B`` in ``R`` the rule
    ``(freqset - B) -> B`` is scored with
    ``confidence = support(freqset) / support(freqset - B)``.
    Consequents whose rule passes are joined with ``mergeToNext`` into
    candidates one item larger and scored recursively. Joining only the
    passing consequents is the pruning step: enlarging a consequent
    shrinks the antecedent, which can only raise the denominator and so
    lower the confidence.

    Args:
        freqset: the frequent itemset (frozenset) to split into rules.
        R: list of candidate consequents, all of the same size, each a
            subset of ``freqset``.
        suppotdata: mapping from itemset to support. It must contain
            ``freqset`` and every antecedent that gets scored. The
            misspelled name is kept so existing callers are unaffected.
        rulelist: output list, mutated in place; each accepted rule is
            appended as ``[antecedent, consequent, confidence]``.
        minconf: minimum confidence a rule must reach.

    Returns:
        None.
    """
    # Read supports from the argument. The body used to read a
    # module-level name ``supportdata`` and failed without that global.
    supportdata = suppotdata
    # Stop when there is nothing to score or the consequent would be the
    # whole itemset (which leaves an empty antecedent).
    if len(R) == 0 or len(R[0]) >= len(freqset):
        return
    kept = []
    for consequent in R:
        antecedent = freqset - consequent
        # conf(A -> B) = support(A | B) / support(A)
        conf = supportdata[freqset] / supportdata[antecedent]
        if conf >= minconf:
            rulelist.append([antecedent, consequent, conf])
            kept.append(consequent)
    # Joining needs at least two surviving consequents.
    if len(kept) < 2:
        return
    GetRules(freqset, mergeToNext(kept), suppotdata, rulelist, minconf)
def genRules(Llist, supportdata, minconf = .5):
    """Generate association rules from the frequent itemsets.

    Args:
        Llist: list of lists of frequent itemsets as returned by
            ``apriori``; the first entry is skipped, and processing stops
            at the first empty entry after it.
        supportdata: mapping from itemset to support, passed through to
            ``GetRules``.
        minconf: minimum confidence, passed through to ``GetRules``.

    Returns:
        The list that ``GetRules`` filled; each entry is
        ``[antecedent, consequent, confidence]``.
    """
    rulelist = []
    for L in Llist[1:]:
        if not L:
            break
        for freqset in L:
            # Start from every single item as a candidate consequent.
            single_item_consequents = [frozenset([x]) for x in freqset]
            GetRules(freqset, single_item_consequents, supportdata, rulelist, minconf)
    return rulelist
def apriori(datalist, minsupport = .5):
    """Find all frequent itemsets level by level (C1 -> L1 -> C2 -> ...).

    Args:
        datalist: iterable of transactions, each an iterable of hashable
            items.
        minsupport: minimum support, passed through to ``genfreq``.

    Returns:
        A tuple ``(Llist, supportdata)``: ``Llist`` holds one list of
        frequent itemsets per level (the last one may be empty) and
        ``supportdata`` maps every frequent itemset to its support.
    """
    dataset = [frozenset(transaction) for transaction in datalist]
    Llist = []
    supportdata = {}
    candidates = genC1(dataset)
    while candidates:
        frequent, supports = genfreq(dataset, candidates, minsupport)
        Llist.append(frequent)
        supportdata.update(supports)
        # Join this level's frequent itemsets into the next candidates.
        candidates = mergeToNext(frequent)
    return Llist, supportdata
def encode_units(x):
    """Binarise a quantity: 1 when at least one unit, otherwise 0.

    Args:
        x: a numeric quantity.

    Returns:
        1 if ``x >= 1``, else 0. Values strictly between 0 and 1 map to 0;
        the two-branch version fell through and returned None for them.
    """
    return 1 if x >= 1 else 0
def test():
    """Run the miner on the ``genData()`` sample and print the results.

    Prints every frequent itemset with its support, then every rule as
    ``antecedent -> consequent conf =  confidence``.
    """
    # The loadDataSet() result used to be assigned here and overwritten on
    # the next line; only the genData() sample was ever used.
    datalist = genData()
    Llist, supportdata = apriori(datalist)
    rulelist = genRules(Llist, supportdata)
    for L in Llist:
        for p in L:
            print(p, supportdata[p])
    for rule in rulelist:
        print(rule[0], '->', rule[1], 'conf = ', rule[2])
def getFranceData():
    """Save the rows whose Country contains 'France' to a separate workbook.

    Reads 'xxx/Online_Retail.xlsx' and writes the filtered rows to
    'xxx/FranceData.xlsx'.
    """
    retail = pd.read_excel('xxx/Online_Retail.xlsx')
    is_france = retail['Country'].str.contains('France')
    france_rows = retail[is_france]
    france_rows.to_excel('xxx/FranceData.xlsx')
if __name__ == "__main__":
    # getFranceData()  # run once to create FranceData.xlsx
    df = pd.read_excel('xxx/FranceData.xlsx')
    # print('Is all France' if len(df[df['Country'].str.contains('France')])==df.shape[0] else 'Not all France')
    df['Description'] = df['Description'].str.strip()
    df['InvoiceNo'] = df['InvoiceNo'].astype('str')
    # Drop invoices whose number contains 'C' (presumably cancelled
    # orders; confirm against the dataset description).
    df = df[~df['InvoiceNo'].str.contains('C')]
    # Show all columns.
    pd.set_option('display.max_columns', None)
    # Show all rows.
    pd.set_option('display.max_rows', None)
    # One row per invoice, one column per product, summed quantity as value.
    basket = (df.groupby(['InvoiceNo', 'Description'])['Quantity'].sum()
              .unstack().reset_index().fillna(0)
              .set_index('InvoiceNo'))
    # NOTE(review): DataFrame.applymap is deprecated in recent pandas in
    # favour of DataFrame.map; kept so older pandas versions still work.
    basket_sets = basket.applymap(encode_units)
    basket_sets.drop('POSTAGE', inplace=True, axis=1)
    # Turn each invoice row into the list of products flagged 1. Iterating
    # rows avoids one scalar .loc lookup per (invoice, product) cell.
    datalist = [
        [col for col, flag in row.items() if flag == 1]
        for _, row in basket_sets.iterrows()
    ]
    Llist, supportdata = apriori(datalist, minsupport = .07)
    rulelist = genRules(Llist, supportdata, minconf = .7)
    for rule in rulelist:
        print(rule[0], '->', rule[1], 'conf = ', rule[2])