ID3算法

ID3是数据挖掘分类中的一种(是一种if-then的模式),其中运用到熵的概念,表示随机变量不确定性的度量

H(x)=-∑pi *log pi

信息增益是指特征A对训练数据集D的信息增益g(D,A),定义为集合D的经验熵H(D)与特征A给定条件下D的经验条件熵H(D|A)之差

g(D,A)=H(D)-H(D|A)

其中H(Y|X)=∑pi H(Y|X=xi)

Pi=P(x=xi)

ID3 是一种自顶向下增长树的贪婪算法,在每个结点选取能最好地分类样例的属性。继续这个过程直到这棵树能完美分类训练样例,或所有的属性都使用过了。

ID3算法流程

ID3(Examples,Target_attribute,Attributes)
Examples 即训练样例集。Target_attribute 是这棵树要预测的目标属性。Attributes
是除目标属性外供学习到的决策树测试的属性列表。返回能正确分类给定
Examples 的决策树。
  创建树的 Root 结点
  如果 Examples 都为正,那么返回 label =+ 的单结点树 Root
 如果 Examples 都为反,那么返回 label =- 的单结点树 Root
  如果 Attributes 为空,那么返回单结点树 Root,label=Examples 中最普遍的
Target_attribute 值
  否则
  A←Attributes 中分类 Examples 能力最好*的属性
 Root 的决策属性←A
 对于A的每个可能值v
 在Root下加一个新的分支对应测试A= vi
 令Examples vi 为Examples中满足A属性值为v i的子集
 如果的子集Examples vi 为空在这个新分支下加一个叶子结点,结点的 label=Examples vi
中最普遍的 Target_attribute 值
 否则在这个新分支下加一个子树 ID3(Examples vi ,Target_attribute, Attributes-{A})
 结束
 返回 Root

其主要代码如下

 1 /**
 2      * 利用源数据构造决策树
 3      * @param node 正在处理处理的节点,
 4      * @param parentAttrValue父节点划分的属性
 5      */
 6     private void buildDecisionTree(AttrNode node, String parentAttrValue,
 7             String[][] remainData, ArrayList<String> remainAttr, boolean isID3) {
 8         node.setParentAttrValue(parentAttrValue);
 9 
10         String attrName = "";
11         double gainValue = 0;
12         double tempValue = 0;
13 
14         // 如果只有1个属性则直接返回
15         if (remainAttr.size() == 1) {
16             System.out.println("attr null");
17             return;
18         }
19 
20         // 选择剩余属性中信息增益最大的作为下一个分类的属性
21         for (int i = 0; i < remainAttr.size(); i++) {
22             // 判断是否用ID3算法还是C4.5算法
23             if (isID3) {
24                 // ID3算法采用的是按照信息增益的值来比
25                 tempValue = computeGain(remainData, remainAttr.get(i));
26             } else {
27                 // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
28                 tempValue = computeGainRatio(remainData, remainAttr.get(i));
29             }
30 
31             if (tempValue > gainValue) {
32                 gainValue = tempValue;
33                 attrName = remainAttr.get(i);
34             }
35         }
36 
37         node.setAttrName(attrName);
38         ArrayList<String> valueTypes = attrValue.get(attrName);
39         remainAttr.remove(attrName);//将选择的属性从剩余的属性集合中去除
40 
41         AttrNode[] childNode = new AttrNode[valueTypes.size()];
42         String[][] rData;
43         for (int i = 0; i < valueTypes.size(); i++) {
44             // 移除非此值类型的数据
45             rData = removeData(remainData, attrName, valueTypes.get(i));
46 
47             childNode[i] = new AttrNode();
48             boolean sameClass = true;
49             ArrayList<String> indexArray = new ArrayList<>();
50             for (int k = 1; k < rData.length; k++) {//rdata[0]保存的是attrName
51                 indexArray.add(rData[k][0]);//将编号统计进去 
52                 // 判断是否为同一类的,是否同为yes或者同为no
53                 if (!rData[k][attrNames.length - 1]
54                         .equals(rData[1][attrNames.length - 1])) {
55                     // 只要有1个不相等,就不是同类型的
56                     sameClass = false;
57                     break;
58                 }
59             }
60 
61             if (!sameClass) {
62                 // 创建新的对象属性,对象的同个引用会出错,rAttr是剩余的属性
63                 ArrayList<String> rAttr = new ArrayList<>();
64                 for (String str : remainAttr) {
65                     rAttr.add(str);
66                 }
67 
68                 buildDecisionTree(childNode[i], valueTypes.get(i), rData,
69                         rAttr, isID3);
70             } else {
71                 // 如果是同种类型,则直接为数据节点
72                 childNode[i].setParentAttrValue(valueTypes.get(i));
73                 childNode[i].setChildDataIndex(indexArray);
74             }
75 
76         }
77         node.setChildAttrNode(childNode);
78     }
View Code

计算信息增益

    /**
     * 为某个属性计算信息增益
     * 
     * @param remainData
     *            剩余的数据
     * @param value
     *            待划分的属性名称
     * @return
     */
    private double computeGain(String[][] remainData, String value) {
        double gainValue = 0;
        // 源熵的大小将会与属性划分后进行比较
        double entropyOri = 0;
        // 子划分熵和
        double childEntropySum = 0;
        // 属性子类型的个数
        int childValueNum = 0;
        // 属性值的种数
        ArrayList<String> attrTypes = attrValue.get(value);
        // 子属性对应的权重比
        HashMap<String, Integer> ratioValues = new HashMap<>();

        for (int i = 0; i < attrTypes.size(); i++) {
            // 首先都统一计数为0
            ratioValues.put(attrTypes.get(i), 0);
        }

        // 还是按照一列,从左往右遍历
        for (int j = 1; j < attrNames.length; j++) {
            // 判断是否到了划分的属性列
            if (value.equals(attrNames[j])) {
                for (int i = 1; i <= remainData.length - 1; i++) {
                    childValueNum = ratioValues.get(remainData[i][j]);
                    // 增加个数并且重新存入
                    childValueNum++;
                    ratioValues.put(remainData[i][j], childValueNum);
                }
            }
        }

        // 计算原熵的大小
        entropyOri = computeEntropy(remainData, value, null, true);
        for (int i = 0; i < attrTypes.size(); i++) {
            double ratio = (double) ratioValues.get(attrTypes.get(i))
                    / (remainData.length - 1);
            childEntropySum += ratio
                    * computeEntropy(remainData, value, attrTypes.get(i), false);

            // System.out.println("ratio:value: " + ratio + " " +
            // computeEntropy(remainData, value,
            // attrTypes.get(i), false));
        }

        // 二者熵相减就是信息增益
        gainValue = entropyOri - childEntropySum;
        return gainValue;
    }
View Code

若使用C4.5就会使用信息增益比

  1 /**
  2      * 计算信息增益比
  3      * 
  4      * @param remainData
  5      *            剩余数据
  6      * @param value
  7      *            待划分属性
  8      * @return
  9      */
 10     private double computeGainRatio(String[][] remainData, String value) {
 11         double gain = 0;
 12         double spiltInfo = 0;
 13         int childValueNum = 0;
 14         // 属性值的种数
 15         ArrayList<String> attrTypes = attrValue.get(value);
 16         // 子属性对应的权重比
 17         HashMap<String, Integer> ratioValues = new HashMap<>();
 18 
 19         for (int i = 0; i < attrTypes.size(); i++) {
 20             // 首先都统一计数为0
 21             ratioValues.put(attrTypes.get(i), 0);
 22         }
 23 
 24         // 还是按照一列,从左往右遍历
 25         for (int j = 1; j < attrNames.length; j++) {
 26             // 判断是否到了划分的属性列
 27             if (value.equals(attrNames[j])) {
 28                 for (int i = 1; i <= remainData.length - 1; i++) {
 29                     childValueNum = ratioValues.get(remainData[i][j]);
 30                     // 增加个数并且重新存入
 31                     childValueNum++;
 32                     ratioValues.put(remainData[i][j], childValueNum);
 33                 }
 34             }
 35         }
 36 
 37         // 计算信息增益
 38         gain = computeGain(remainData, value);
 39         // 计算分裂信息,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀):
 40         for (int i = 0; i < attrTypes.size(); i++) {
 41             double ratio = (double) ratioValues.get(attrTypes.get(i))
 42                     / (remainData.length - 1);
 43             spiltInfo += -ratio * Math.log(ratio) / Math.log(2.0);
 44         }
 45 
 46         // 计算机信息增益率
 47         return gain / spiltInfo;
 48     }
 49 
 50     /**
 51      * 利用源数据构造决策树
 52      * @param node 正在处理处理的节点,
 53      * @param parentAttrValue父节点划分的属性
 54      */
 55     private void buildDecisionTree(AttrNode node, String parentAttrValue,
 56             String[][] remainData, ArrayList<String> remainAttr, boolean isID3) {
 57         node.setParentAttrValue(parentAttrValue);
 58 
 59         String attrName = "";
 60         double gainValue = 0;
 61         double tempValue = 0;
 62 
 63         // 如果只有1个属性则直接返回
 64         if (remainAttr.size() == 1) {
 65             System.out.println("attr null");
 66             return;
 67         }
 68 
 69         // 选择剩余属性中信息增益最大的作为下一个分类的属性
 70         for (int i = 0; i < remainAttr.size(); i++) {
 71             // 判断是否用ID3算法还是C4.5算法
 72             if (isID3) {
 73                 // ID3算法采用的是按照信息增益的值来比
 74                 tempValue = computeGain(remainData, remainAttr.get(i));
 75             } else {
 76                 // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
 77                 tempValue = computeGainRatio(remainData, remainAttr.get(i));
 78             }
 79 
 80             if (tempValue > gainValue) {
 81                 gainValue = tempValue;
 82                 attrName = remainAttr.get(i);
 83             }
 84         }
 85 
 86         node.setAttrName(attrName);
 87         ArrayList<String> valueTypes = attrValue.get(attrName);
 88         remainAttr.remove(attrName);//将选择的属性从剩余的属性集合中去除
 89 
 90         AttrNode[] childNode = new AttrNode[valueTypes.size()];
 91         String[][] rData;
 92         for (int i = 0; i < valueTypes.size(); i++) {
 93             // 移除非此值类型的数据
 94             rData = removeData(remainData, attrName, valueTypes.get(i));
 95 
 96             childNode[i] = new AttrNode();
 97             boolean sameClass = true;
 98             ArrayList<String> indexArray = new ArrayList<>();
 99             for (int k = 1; k < rData.length; k++) {//rdata[0]保存的是attrName
100                 indexArray.add(rData[k][0]);//将编号统计进去 
101                 // 判断是否为同一类的,是否同为yes或者同为no
102                 if (!rData[k][attrNames.length - 1]
103                         .equals(rData[1][attrNames.length - 1])) {
104                     // 只要有1个不相等,就不是同类型的
105                     sameClass = false;
106                     break;
107                 }
108             }
109 
110             if (!sameClass) {
111                 // 创建新的对象属性,对象的同个引用会出错,rAttr是剩余的属性
112                 ArrayList<String> rAttr = new ArrayList<>();
113                 for (String str : remainAttr) {
114                     rAttr.add(str);
115                 }
116 
117                 buildDecisionTree(childNode[i], valueTypes.get(i), rData,
118                         rAttr, isID3);
119             } else {
120                 // 如果是同种类型,则直接为数据节点
121                 childNode[i].setParentAttrValue(valueTypes.get(i));
122                 childNode[i].setChildDataIndex(indexArray);
123             }
124 
125         }
126         node.setChildAttrNode(childNode);
127     }
View Code

ID3 : 归纳偏置的更贴切近似:较短的树比较长的得到优先。那些信息增益高的属性
更靠近根结点的树得到优先

 

posted on 2015-03-19 22:12  未选择的路  阅读(720)  评论(0编辑  收藏  举报

导航