CBA算法---基于关联规则进行分类的算法
更多数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm
介绍
CBA算法全称是Classification base of Association,就是基于关联规则进行分类的算法,说到关联规则,我们就会想到Apriori和FP-Tree算法都是关联规则挖掘算法,而CBA算法正是利用了Apriori挖掘出的关联规则,然后做分类判断,所以在某种程度上说,CBA算法也可以说是一种集成挖掘算法。
算法原理
CBA算法作为分类算法,他的分类情况也就是给定一些预先知道的属性,然后叫你判断出他的决策属性是哪个值。判断的依据就是Apriori算法挖掘出的频繁项,如果一个项集中包含预先知道的属性,同时也包含分类属性值,然后我们计算此频繁项能否导出已知属性值推出决策属性值的关联规则,如果满足规则的最小置信度的要求,那么可以把频繁项中的决策属性值作为最后的分类结果。具体的算法细节如下:
1、输入数据记录,就是一条条的属性值。
2、对属性值做数字的替换(按照列从上往下寻找属性值),就类似于Apriori中的一条条事务记录。
3、根据这个转化后的事务记录,进行Apriori算法计算,挖掘出频繁项集。
4、输入查询的属性值,找出符合条件的频繁项集(需要包含查询属性和分类决策属性),如果能够推导出这样的关联规则,就算分类成功,输出分类结果。
这里以之前我做的CART算法的测试数据为CBA算法的测试数据,如下:
Rid Age Income Student CreditRating BuysComputer
1 13 High No Fair CLassNo
2 11 High No Excellent CLassNo
3 25 High No Fair CLassYes
4 45 Medium No Fair CLassYes
5 50 Low Yes Fair CLassYes
6 51 Low Yes Excellent CLassNo
7 30 Low Yes Excellent CLassYes
8 13 Medium No Fair CLassNo
9 9 Low Yes Fair CLassYes
10 55 Medium Yes Fair CLassYes
11 14 Medium Yes Excellent CLassYes
12 33 Medium No Excellent CLassYes
13 33 High Yes Fair CLassYes
14 41 Medium No Excellent CLassNo
属性值对应的数字替换图:
Medium=5, CLassYes=12, Excellent=10, Low=6, Fair=9, CLassNo=11, Young=1, Middle_aged=2, Yes=8, No=7, High=4, Senior=3
体会之后的数据变为了下面的事务项:
Rid Age Income Student CreditRating BuysComputer
1 1 4 7 9 11
2 1 4 7 10 11
3 2 4 7 9 12
4 3 5 7 9 12
5 3 6 8 9 12
6 3 6 8 10 11
7 2 6 8 10 12
8 1 5 7 9 11
9 1 6 8 9 12
10 3 5 8 9 12
11 1 5 8 10 12
12 2 5 7 10 12
13 2 4 8 9 12
14 3 5 7 10 11
把每条记录看出事务项,就和Apriori算法的输入格式基本一样了,后面就是进行连接运算和剪枝步骤等Apriori算法的步骤了,在这里就不详细描述了,Apriori算法的实现可以点击这里进行了解。
算法的代码实现
测试数据就是上面的内容。
CBATool.java:
package DataMining_CBA;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import DataMining_CBA.AprioriTool.AprioriTool;
import DataMining_CBA.AprioriTool.FrequentItem;
/**
* CBA算法(关联规则分类)工具类
*
* @author lyq
*
*/
public class CBATool {
// 年龄的类别划分
public final String AGE = "Age";
public final String AGE_YOUNG = "Young";
public final String AGE_MIDDLE_AGED = "Middle_aged";
public final String AGE_Senior = "Senior";
// 测试数据地址
private String filePath;
// 最小支持度阈值率
private double minSupportRate;
// 最小置信度阈值,用来判断是否能够成为关联规则
private double minConf;
// 最小支持度
private int minSupportCount;
// 属性列名称
private String[] attrNames;
// 类别属性所代表的数字集合
private ArrayList<Integer> classTypes;
// 用二维数组保存测试数据
private ArrayList<String[]> totalDatas;
// Apriori算法工具类
private AprioriTool aprioriTool;
// 属性到数字的映射图
private HashMap<String, Integer> attr2Num;
private HashMap<Integer, String> num2Attr;
public CBATool(String filePath, double minSupportRate, double minConf) {
this.filePath = filePath;
this.minConf = minConf;
this.minSupportRate = minSupportRate;
readDataFile();
}
/**
* 从文件中读取数据
*/
private void readDataFile() {
File file = new File(filePath);
ArrayList<String[]> dataArray = new ArrayList<String[]>();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
String[] tempArray;
while ((str = in.readLine()) != null) {
tempArray = str.split(" ");
dataArray.add(tempArray);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
totalDatas = new ArrayList<>();
for (String[] array : dataArray) {
totalDatas.add(array);
}
attrNames = totalDatas.get(0);
minSupportCount = (int) (minSupportRate * totalDatas.size());
attributeReplace();
}
/**
* 属性值的替换,替换成数字的形式,以便进行频繁项的挖掘
*/
private void attributeReplace() {
int currentValue = 1;
int num = 0;
String s;
// 属性名到数字的映射图
attr2Num = new HashMap<>();
num2Attr = new HashMap<>();
classTypes = new ArrayList<>();
// 按照1列列的方式来,从左往右边扫描,跳过列名称行和id列
for (int j = 1; j < attrNames.length; j++) {
for (int i = 1; i < totalDatas.size(); i++) {
s = totalDatas.get(i)[j];
// 如果是数字形式的,这里只做年龄类别转换,其他的数字情况类似
if (attrNames[j].equals(AGE)) {
num = Integer.parseInt(s);
if (num <= 20 && num > 0) {
totalDatas.get(i)[j] = AGE_YOUNG;
} else if (num > 20 && num <= 40) {
totalDatas.get(i)[j] = AGE_MIDDLE_AGED;
} else if (num > 40) {
totalDatas.get(i)[j] = AGE_Senior;
}
}
if (!attr2Num.containsKey(totalDatas.get(i)[j])) {
attr2Num.put(totalDatas.get(i)[j], currentValue);
num2Attr.put(currentValue, totalDatas.get(i)[j]);
if (j == attrNames.length - 1) {
// 如果是组后一列,说明是分类类别列,记录下来
classTypes.add(currentValue);
}
currentValue++;
}
}
}
// 对原始的数据作属性替换,每条记录变为类似于事务数据的形式
for (int i = 1; i < totalDatas.size(); i++) {
for (int j = 1; j < attrNames.length; j++) {
s = totalDatas.get(i)[j];
if (attr2Num.containsKey(s)) {
totalDatas.get(i)[j] = attr2Num.get(s) + "";
}
}
}
}
/**
* Apriori计算全部频繁项集
* @return
*/
private ArrayList<FrequentItem> aprioriCalculate() {
String[] tempArray;
ArrayList<FrequentItem> totalFrequentItems;
ArrayList<String[]> copyData = (ArrayList<String[]>) totalDatas.clone();
// 去除属性名称行
copyData.remove(0);
// 去除首列ID
for (int i = 0; i < copyData.size(); i++) {
String[] array = copyData.get(i);
tempArray = new String[array.length - 1];
System.arraycopy(array, 1, tempArray, 0, tempArray.length);
copyData.set(i, tempArray);
}
aprioriTool = new AprioriTool(copyData, minSupportCount);
aprioriTool.computeLink();
totalFrequentItems = aprioriTool.getTotalFrequentItems();
return totalFrequentItems;
}
/**
* 基于关联规则的分类
*
* @param attrValues
* 预先知道的一些属性
* @return
*/
public String CBAJudge(String attrValues) {
int value = 0;
// 最终分类类别
String classType = null;
String[] tempArray;
// 已知的属性值
ArrayList<String> attrValueList = new ArrayList<>();
ArrayList<FrequentItem> totalFrequentItems;
totalFrequentItems = aprioriCalculate();
// 将查询条件进行逐一属性的分割
String[] array = attrValues.split(",");
for (String record : array) {
tempArray = record.split("=");
value = attr2Num.get(tempArray[1]);
attrValueList.add(value + "");
}
// 在频繁项集中寻找符合条件的项
for (FrequentItem item : totalFrequentItems) {
// 过滤掉不满足个数频繁项
if (item.getIdArray().length < (attrValueList.size() + 1)) {
continue;
}
// 要保证查询的属性都包含在频繁项集中
if (itemIsSatisfied(item, attrValueList)) {
tempArray = item.getIdArray();
classType = classificationBaseRules(tempArray);
if (classType != null) {
// 作属性替换
classType = num2Attr.get(Integer.parseInt(classType));
break;
}
}
}
return classType;
}
/**
* 基于关联规则进行分类
*
* @param items
* 频繁项
* @return
*/
private String classificationBaseRules(String[] items) {
String classType = null;
String[] arrayTemp;
int count1 = 0;
int count2 = 0;
// 置信度
double confidenceRate;
String[] noClassTypeItems = new String[items.length - 1];
for (int i = 0, k = 0; i < items.length; i++) {
if (!classTypes.contains(Integer.parseInt(items[i]))) {
noClassTypeItems[k] = items[i];
k++;
} else {
classType = items[i];
}
}
for (String[] array : totalDatas) {
// 去除ID数字号
arrayTemp = new String[array.length - 1];
System.arraycopy(array, 1, arrayTemp, 0, array.length - 1);
if (isStrArrayContain(arrayTemp, noClassTypeItems)) {
count1++;
if (isStrArrayContain(arrayTemp, items)) {
count2++;
}
}
}
// 做置信度的计算
confidenceRate = count1 * 1.0 / count2;
if (confidenceRate >= minConf) {
return classType;
} else {
// 如果不满足最小置信度要求,则此关联规则无效
return null;
}
}
/**
* 判断单个字符是否包含在字符数组中
*
* @param array
* 字符数组
* @param s
* 判断的单字符
* @return
*/
private boolean strIsContained(String[] array, String s) {
boolean isContained = false;
for (String str : array) {
if (str.equals(s)) {
isContained = true;
break;
}
}
return isContained;
}
/**
* 数组array2是否包含于array1中,不需要完全一样
*
* @param array1
* @param array2
* @return
*/
private boolean isStrArrayContain(String[] array1, String[] array2) {
boolean isContain = true;
for (String s2 : array2) {
isContain = false;
for (String s1 : array1) {
// 只要s2字符存在于array1中,这个字符就算包含在array1中
if (s2.equals(s1)) {
isContain = true;
break;
}
}
// 一旦发现不包含的字符,则array2数组不包含于array1中
if (!isContain) {
break;
}
}
return isContain;
}
/**
* 判断频繁项集是否满足查询
*
* @param item
* 待判断的频繁项集
* @param attrValues
* 查询的属性值列表
* @return
*/
private boolean itemIsSatisfied(FrequentItem item,
ArrayList<String> attrValues) {
boolean isContained = false;
String[] array = item.getIdArray();
for (String s : attrValues) {
isContained = true;
if (!strIsContained(array, s)) {
isContained = false;
break;
}
if (!isContained) {
break;
}
}
if (isContained) {
isContained = false;
// 还要验证是否频繁项集中是否包含分类属性
for (Integer type : classTypes) {
if (strIsContained(array, type + "")) {
isContained = true;
break;
}
}
}
return isContained;
}
}
调用类Client.java:
package DataMining_CBA;
import java.text.MessageFormat;
/**
* CBA算法--基于关联规则的分类算法
* @author lyq
*
*/
public class Client {
public static void main(String[] args){
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
String attrDesc = "Age=Senior,CreditRating=Fair";
String classification = null;
//最小支持度阈值率
double minSupportRate = 0.2;
//最小置信度阈值
double minConf = 0.7;
CBATool tool = new CBATool(filePath, minSupportRate, minConf);
classification = tool.CBAJudge(attrDesc);
System.out.println(MessageFormat.format("{0}的关联分类结果为{1}", attrDesc, classification));
}
}
代码的结果为:
频繁1项集:
{1,},{2,},{3,},{4,},{5,},{6,},{7,},{8,},{9,},{10,},{11,},{12,},
频繁2项集:
{1,7,},{1,9,},{1,11,},{2,12,},{3,5,},{3,8,},{3,9,},{3,12,},{4,7,},{4,9,},{5,7,},{5,9,},{5,10,},{5,12,},{6,8,},{6,12,},{7,9,},{7,10,},{7,11,},{7,12,},{8,9,},{8,10,},{8,12,},{9,12,},{10,11,},{10,12,},
频繁3项集:
{1,7,11,},{3,9,12,},{6,8,12,},{8,9,12,},
频繁4项集:
频繁5项集:
频繁6项集:
频繁7项集:
频繁8项集:
频繁9项集:
频繁10项集:
频繁11项集:
Age=Senior,CreditRating=Fair的关联分类结果为CLassYes
上面的有些项集为空说明没有此项集。Apriori算法类可以在这里进行查阅,这里只展示了CBA算法的部分。
算法的分析
我在准备实现CBA算法的时候就预见到了这个算法就是对Apriori算法的一个包装,在于2点,输入数据的格式进行数字的转换,还有就是输出的时候做属性对数字的替换,核心还是在于Apriori算法的项集频繁挖掘。
程序实现时遇到的问题
在这期间遇到了一个bug就是频繁1项集在排序的时候出现了问题,后来发现原因是String.CompareTo(),原本应该是1,2,....11,12,用了前面这个方法后会变成1,10,2,。。就是10会比2小的情况,后来查了String.CompareTo()的比较规则,明白了他是一位位比较Ascall码值,因为10的1比2小,最后果断的改回了用Integer的比较方法了。这个问题别看是个小问题,1项集如果没有排好序,后面的连接操作可能会出现少情况的可能,这个之前吃过这样的亏了。
我对CBA算法的理解
CBA算法和巧妙的利用了关联规则进行类别的分类,有别与其他的分类算法。他的算法好坏又会依靠Apriori算法的执行好坏。