ID3 Algorithm

1. A Brief Introduction to ID3

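ID3 (Iterative Dichotomiser 3) was proposed by Ross Quinlan at the University of Sydney, who first presented it in 1975. ID3 builds on the Concept Learning System (CLS) algorithm and is a decision-tree algorithm.

A decision tree consists of decision nodes, branches and leaves. The topmost node is the root, and each branch leads either to a new decision node or to a leaf. Each decision node represents a question or decision, usually corresponding to an attribute of the object being classified, and each leaf represents one possible classification. To classify an object, we walk down the tree from the root: every node applies a test, each outcome of that test leads down a different branch, and the walk ends at a leaf. In this way several variables are combined to decide which class the object belongs to.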
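The top-down walk can be made concrete with a minimal Python sketch. It uses a representation of our own, not the author's: an internal node is a dict that stores the tested attribute under "attr" and one entry per attribute value under "branches", and a leaf is just the class label. The hard-coded tree is the one obtained in the worked example of section 3.

```python
# Hypothetical representation: an internal node is a dict naming the attribute
# it tests plus one branch per attribute value; a leaf is a plain class label.
TREE = {
    "attr": "age",
    "branches": {
        "middle_aged": "yes",
        "youth": {"attr": "student", "branches": {"yes": "yes", "no": "no"}},
        "senior": {"attr": "credit_rating",
                   "branches": {"fair": "yes", "excellent": "no"}},
    },
}


def classify(node, instance):
    """Follow the branch matching each tested attribute until a leaf is reached."""
    while isinstance(node, dict):        # internal (decision) node
        value = instance[node["attr"]]    # outcome of this node's test
        node = node["branches"][value]    # descend along that branch
    return node                           # leaf: the predicted class


print(classify(TREE, {"age": "youth", "income": "medium",
                      "student": "yes", "credit_rating": "fair"}))  # yes
```

The loop only reads the attributes that the nodes on its path actually test. In this example, income is never consulted.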

2. Foundations of ID3: Information Theory

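ID3 is grounded in information theory. It uses information entropy and information gain as its measures when deciding how to split the data for classification. The basic concepts from information theory are: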

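Definition 1: If there are n equally likely messages, each has probability p = 1/n, and the information conveyed by one message is -log2(1/n) = log2(n).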

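Definition 2: If there are n messages with probability distribution P = (p1, p2, ..., pn), the information conveyed by this distribution is called the entropy of P, written I(P) = -p1*log2(p1) - p2*log2(p2) - ... - pn*log2(pn). Terms with pi = 0 are taken to be 0.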
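Here is a quick numerical check of Definitions 1 and 2. The function name `entropy` is our own. The check shows that the entropy of the uniform distribution over n messages equals the log2(n) bits of Definition 1.

```python
import math


def entropy(probs):
    """I(P) = -(p1*log2(p1) + ... + pn*log2(pn)); terms with p == 0 are skipped (0*log2(0) = 0)."""
    return sum(-p * math.log2(p) for p in probs if p > 0)


n = 8
# Definition 1: one of n equally likely messages carries -log2(1/n) = log2(n) bits.
print(-math.log2(1 / n))                          # 3.0
# Definition 2 applied to the uniform distribution gives the same value.
print(entropy([1 / n] * n))                       # 3.0
# A skewed distribution carries less information than a uniform one.
print(entropy([0.9, 0.1]) < entropy([0.5, 0.5]))  # True
```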

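Definition 3: If a set of records T is partitioned by the value of the class attribute into disjoint classes C1, C2, ..., Ck, then the information needed to identify the class of an element of T is Info(T) = I(P), where P is the probability distribution of C1, C2, ..., Ck, that is, P = (|C1|/|T|, |C2|/|T|, ..., |Ck|/|T|).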

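Definition 4: If we first partition T by the values of a non-class attribute X into subsets T1, T2, ..., Tn, the information needed to identify the class of an element of T is the weighted average of the Info(Ti): Info(X,T) = sum over i = 1..n of (|Ti|/|T|) * Info(Ti).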

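Definition 5: The information gain is the difference between two amounts of information. The first is the information needed to identify the class of an element of T. The second is the information still needed once the value of attribute X is known. The formula is Gain(X,T) = Info(T) - Info(X,T).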
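Definitions 3 to 5 translate almost line by line into code. This sketch is illustrative only. The names `info`, `info_x` and `gain` are our own, records are dicts whose class label sits under a hypothetical key "class", and the four-record set T is made up so the numbers can be checked by hand.

```python
import math
from collections import Counter, defaultdict


def info(labels):
    """Definition 3: Info(T) = I(|C1|/|T|, ..., |Ck|/|T|) from T's class labels."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def info_x(records, x):
    """Definition 4: Info(X,T) = sum_i |Ti|/|T| * Info(Ti), grouping T by attribute x."""
    groups = defaultdict(list)
    for rec in records:
        groups[rec[x]].append(rec["class"])
    n = len(records)
    return sum(len(ti) / n * info(ti) for ti in groups.values())


def gain(records, x):
    """Definition 5: Gain(X,T) = Info(T) - Info(X,T)."""
    return info([rec["class"] for rec in records]) - info_x(records, x)


# Made-up set T: X separates the two classes perfectly, Y tells us nothing.
T = [
    {"X": "a", "Y": "p", "class": "+"},
    {"X": "a", "Y": "q", "class": "+"},
    {"X": "b", "Y": "p", "class": "-"},
    {"X": "b", "Y": "q", "class": "-"},
]
print(gain(T, "X"), gain(T, "Y"))  # 1.0 0.0
```

Knowing X removes the full bit of uncertainty in T, while knowing Y removes none. ID3 would therefore split on X.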

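ID3 computes the information gain of every attribute and selects the attribute with the highest gain as the test attribute for the current set. It creates a node labeled with that attribute, creates one branch for each of the attribute's values, and then repeats the gain computation recursively on each branch.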
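The select-and-split loop just described can be sketched compactly in Python. This is our own illustration, not the author's program (which appears in section 4). It assumes that rows are dicts and that the tree is returned as a nested dict {attribute: {value: subtree}}. A node becomes a leaf when it is pure or when no attributes remain, in which case it takes the majority class. The helper names are hypothetical. The sketch runs on the 14-record example from section 3.

```python
import math
from collections import Counter

ROWS = """\
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no"""
COLS = ["age", "income", "student", "credit_rating", "buy_computer"]
DATA = [dict(zip(COLS, line.split())) for line in ROWS.splitlines()]


def entropy(labels):
    """Entropy in bits of the class distribution of a list of labels."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def gain(rows, attr, target):
    """Information gain of splitting rows on attr."""
    n = len(rows)
    cond = 0.0
    for value in {r[attr] for r in rows}:
        subset = [r[target] for r in rows if r[attr] == value]
        cond += len(subset) / n * entropy(subset)
    return entropy([r[target] for r in rows]) - cond


def id3(rows, attrs, target):
    """Return a class label (leaf) or a nested dict {attr: {value: subtree}}."""
    labels = [r[target] for r in rows]
    if len(set(labels)) == 1:          # pure node: make a leaf
        return labels[0]
    if not attrs:                      # no attributes left: majority class
        return Counter(labels).most_common(1)[0][0]
    best = max(attrs, key=lambda a: gain(rows, a, target))  # highest gain wins
    rest = [a for a in attrs if a != best]  # drop best for this branch only
    return {best: {v: id3([r for r in rows if r[best] == v], rest, target)
                   for v in sorted({r[best] for r in rows})}}


tree = id3(DATA, ["age", "income", "student", "credit_rating"], "buy_computer")
print(tree)
```

Each recursive call gets its own copy of the remaining attributes (`rest`). This way, removing an attribute on one branch does not hide it from sibling branches.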

3. A Worked Example of ID3

Consider the following training set:

RID age income student credit_rating buy_computer
1 youth high no fair no
2 youth high no excellent no
3 middle_aged high no fair yes
4 senior medium no fair yes
5 senior low yes fair yes
6 senior low yes excellent no
7 middle_aged low yes excellent yes
8 youth medium no fair no
9 youth low yes fair yes
10 senior medium yes fair yes
11 youth medium yes excellent yes
12 middle_aged medium no excellent yes
13 middle_aged high yes fair yes
14 senior medium no excellent no

 

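The dataset has 14 records; the numbers in brackets below are record counts. The predictor attributes are age (youth[5], middle_aged[4], senior[5]), income (high[4], medium[6], low[4]), student (no[7], yes[7]) and credit_rating (fair[8], excellent[6]). The target attribute is buy_computer (no[5], yes[9]). The goal is to predict the value of buy_computer from age, income, student and credit_rating. Let D be the initial dataset and A the list of predictor attributes. The computation proceeds as follows: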

Step 1: Compute the entropy of the target attribute over D: Info(buy_computer) = -(5/14)*log2(5/14) - (9/14)*log2(9/14) = 0.940
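Step 1 can be checked in a few lines of Python. The sketch encodes the 14 records of the table above as one whitespace-separated line each (RID omitted). The names `ROWS`, `COLS`, `DATA` and `entropy` are our own.

```python
import math
from collections import Counter

# The 14 training records from the table above, in RID order.
ROWS = """\
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no"""
COLS = ["age", "income", "student", "credit_rating", "buy_computer"]
DATA = [dict(zip(COLS, line.split())) for line in ROWS.splitlines()]


def entropy(labels):
    """Entropy in bits of the class distribution of a list of labels."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


labels = [r["buy_computer"] for r in DATA]
print(Counter(labels))            # Counter({'yes': 9, 'no': 5})
print(round(entropy(labels), 3))  # 0.94
```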

Step 2: For each attribute in A, compute over D the entropy of buy_computer that remains once the value of that attribute is known. This is the conditional entropy.

  age: youth (no[3], yes[2]), middle_aged (no[0], yes[4]), senior (no[2], yes[3]). First compute the entropy within each of youth, middle_aged and senior:

    Info_age(buy_computer | youth) = -(3/5)*log2(3/5) - (2/5)*log2(2/5) = 0.971

    Info_age(buy_computer | middle_aged) = -(4/4)*log2(4/4) - (0/4)*log2(0/4) = 0 (taking 0*log2(0) = 0)

    Info_age(buy_computer | senior) = -(2/5)*log2(2/5) - (3/5)*log2(3/5) = 0.971

  Then Info_age(buy_computer) = (5/14)*0.971 + (4/14)*0 + (5/14)*0.971 = 0.694

  Similarly: Info_income(buy_computer) = 0.911; Info_student(buy_computer) = 0.789; Info_credit_rating(buy_computer) = 0.892.
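Step 2 can be reproduced with the sketch below, under the same assumptions as before. The table is re-encoded with RID omitted, and `cond_entropy` is a hypothetical helper name. The function groups the records by the attribute's value and weights each group's entropy by the group's share of the records.

```python
import math
from collections import Counter

ROWS = """\
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no"""
COLS = ["age", "income", "student", "credit_rating", "buy_computer"]
DATA = [dict(zip(COLS, line.split())) for line in ROWS.splitlines()]
ATTRS = ["age", "income", "student", "credit_rating"]


def entropy(labels):
    """Entropy in bits of the class distribution of a list of labels."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def cond_entropy(rows, attr, target="buy_computer"):
    """Info_attr(target): entropy of target within each value of attr, weighted by group size."""
    n = len(rows)
    total = 0.0
    for value in {r[attr] for r in rows}:
        subset = [r[target] for r in rows if r[attr] == value]
        total += len(subset) / n * entropy(subset)
    return total


for a in ATTRS:
    print(a, round(cond_entropy(DATA, a), 3))
```

Rounded to three decimals, this prints 0.694, 0.911, 0.788 and 0.892. For student the exact value is about 0.7885. The text's 0.789 results from rounding the two per-value entropies before averaging them.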

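Step 3: Compute the information gain. The larger the gain, the more the entropy of the target attribute drops once this attribute is known, and the higher in the tree the attribute should be placed. The results are: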

  Gain(age, buy_computer) = Info(buy_computer) - Info_age(buy_computer) = 0.94 - 0.694 = 0.246

  Gain(income, buy_computer) = Info(buy_computer) - Info_income(buy_computer) = 0.94 - 0.911 = 0.029

  Gain(student, buy_computer) = Info(buy_computer) - Info_student(buy_computer) = 0.94 - 0.789 = 0.151

  Gain(credit_rating, buy_computer) = Info(buy_computer) - Info_credit_rating(buy_computer) = 0.94 - 0.892 = 0.048
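This sketch (helper names our own, table re-encoded as before) computes the four gains and picks the largest, which is the choice made in Step 4.

```python
import math
from collections import Counter

ROWS = """\
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no"""
COLS = ["age", "income", "student", "credit_rating", "buy_computer"]
DATA = [dict(zip(COLS, line.split())) for line in ROWS.splitlines()]
ATTRS = ["age", "income", "student", "credit_rating"]


def entropy(labels):
    """Entropy in bits of the class distribution of a list of labels."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def cond_entropy(rows, attr, target="buy_computer"):
    """Weighted entropy of target within each value of attr."""
    n = len(rows)
    total = 0.0
    for value in {r[attr] for r in rows}:
        subset = [r[target] for r in rows if r[attr] == value]
        total += len(subset) / n * entropy(subset)
    return total


def gain(rows, attr, target="buy_computer"):
    """Gain(attr) = Info(target) - Info_attr(target)."""
    return entropy([r[target] for r in rows]) - cond_entropy(rows, attr, target)


gains = {a: gain(DATA, a) for a in ATTRS}
for a, g in sorted(gains.items(), key=lambda kv: -kv[1]):
    print(a, round(g, 3))
best = max(gains, key=gains.get)
print("split on:", best)  # split on: age
```

Without intermediate rounding, the gains come out as 0.247 (age), 0.152 (student), 0.048 (credit_rating) and 0.029 (income). The text's 0.246 and 0.151 result from subtracting values that were already rounded. The ranking, and therefore the choice of age, is the same.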

Step 4: Choose the attribute with the largest information gain, here age, as the current node, and split the initial dataset D by the values of age:

  1. When age is youth, the subset is D1:

RID income student credit_rating buy_computer
1 high no fair no
2 high no excellent no
8 medium no fair no
9 low yes fair yes
11 medium yes excellent yes

  2. When age is middle_aged, the subset is D2:

RID income student credit_rating buy_computer
3 high no fair yes
7 low yes excellent yes
12 medium no excellent yes
13 high yes fair yes

  3. When age is senior, the subset is D3:

RID income student credit_rating buy_computer
4 medium no fair yes
5 low yes fair yes
6 low yes excellent no
10 medium yes fair yes
14 medium no excellent no
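The split in Step 4 is a simple group-by on the chosen attribute. In this sketch the table is re-encoded with its RID column, and `partition` is a hypothetical helper name.

```python
from collections import defaultdict

ROWS = """\
1 youth high no fair no
2 youth high no excellent no
3 middle_aged high no fair yes
4 senior medium no fair yes
5 senior low yes fair yes
6 senior low yes excellent no
7 middle_aged low yes excellent yes
8 youth medium no fair no
9 youth low yes fair yes
10 senior medium yes fair yes
11 youth medium yes excellent yes
12 middle_aged medium no excellent yes
13 middle_aged high yes fair yes
14 senior medium no excellent no"""
COLS = ["RID", "age", "income", "student", "credit_rating", "buy_computer"]
DATA = [dict(zip(COLS, line.split())) for line in ROWS.splitlines()]


def partition(rows, attr):
    """Step 4: group rows by their value of attr, keeping the original row order."""
    subsets = defaultdict(list)
    for r in rows:
        subsets[r[attr]].append(r)
    return dict(subsets)


parts = partition(DATA, "age")
for value, subset in parts.items():
    print(value, [r["RID"] for r in subset])
```

This prints RIDs 1, 2, 8, 9, 11 for youth (D1), 3, 7, 12, 13 for middle_aged (D2) and 4, 5, 6, 10, 14 for senior (D3). These match the three subsets listed above.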

 

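Step 5: Remove the chosen attribute (age) from the attribute list A. Then, for each subset Di produced in Step 4, repeat the procedure from Step 1 using the reduced list A. The iteration stops when: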

  1. all records in a subset have the same value of the target attribute, as in the subset where age is middle_aged;
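  2. the share of one target value in a subset reaches a threshold chosen by the user; for example, iteration can stop as soon as one value of buy_computer accounts for 90% or more of the subset.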
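Both stopping conditions can be folded into a single check. This is a sketch under stated assumptions. `should_stop` is a hypothetical helper, and its 0.9 default mirrors the 90% example. The extra "no attributes left" test is not one of the two conditions above; it is what the Python and Java programs later in this post do when the attribute list runs out. In every case the leaf predicts the majority class.

```python
from collections import Counter


def should_stop(labels, remaining_attrs, ratio=0.9):
    """Decide whether a node becomes a leaf; returns (stop, predicted_label).

    Condition 1: all labels are equal (share == 1.0).
    Condition 2: the most common label reaches `ratio`, a user-chosen threshold.
    Extra check (as in the programs below): no attributes left to split on.
    """
    label, count = Counter(labels).most_common(1)[0]
    share = count / len(labels)
    if share == 1.0 or share >= ratio or not remaining_attrs:
        return True, label
    return False, None


print(should_stop(["yes"] * 4, ["income"]))                       # (True, 'yes')
print(should_stop(["no", "no", "no", "yes", "yes"], ["income"]))  # (False, None)
print(should_stop(["yes"] * 9 + ["no"], ["income"]))              # (True, 'yes')
```

The first call corresponds to the middle_aged subset D2, which is pure. The second corresponds to the youth subset D1, where only 60% of the records are "no", so splitting continues.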

After several iterations we obtain the final tree structure, shown in the figure below:

The resulting rules are:

IF AGE = middle_aged, THEN buy_computer = yes

IF AGE = youth AND STUDENT = yes, THEN buy_computer = yes

IF AGE = youth AND STUDENT = no, THEN buy_computer = no

IF AGE = senior AND CREDIT_RATING = excellent, THEN buy_computer = no

IF AGE = senior AND CREDIT_RATING = fair, THEN buy_computer = yes

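So for the instance ("15", "youth", "medium", "yes", "fair"), a youth with medium income who is a student and has a fair credit rating, the predicted value of buy_computer is "yes", by the rule for AGE = youth AND STUDENT = yes.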
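The five rules can also be applied directly, without building the tree. In this sketch, the names `RULES` and `predict` are our own. Each rule is a dict of attribute tests plus a label, and the first rule whose tests all match wins. Because the rules come from disjoint paths of the tree, at most one rule can match any record.

```python
# The five IF-THEN rules above: (attribute tests, predicted buy_computer value).
RULES = [
    ({"age": "middle_aged"}, "yes"),
    ({"age": "youth", "student": "yes"}, "yes"),
    ({"age": "youth", "student": "no"}, "no"),
    ({"age": "senior", "credit_rating": "excellent"}, "no"),
    ({"age": "senior", "credit_rating": "fair"}, "yes"),
]


def predict(record):
    """Return the label of the first rule whose tests all match, else None."""
    for conditions, label in RULES:
        if all(record.get(k) == v for k, v in conditions.items()):
            return label
    return None  # no rule covers this combination of values


cols = ["RID", "age", "income", "student", "credit_rating"]
instance = dict(zip(cols, ("15", "youth", "medium", "yes", "fair")))
print(predict(instance))  # yes
```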

 

4. Implementing ID3

Below are implementations of ID3 in Python and in Java:

Python program:

# -*- coding: utf-8 -*-


class Node:
    '''Represents a decision tree node.'''
    def __init__(self, parent=None, dataset=None):
        self.dataset = dataset  # training instances that fall into this node
        self.result = None      # class label (set only on leaf nodes)
        self.attr = None        # index of the attribute this node splits on
        self.childs = {}        # subtrees as key-value pairs: (value of attr, child subtree)
        self.parent = parent    # parent node


def entropy(props):
    '''Entropy (in bits) of a probability distribution given as a tuple or list.'''
    if not isinstance(props, (tuple, list)):
        return None

    from math import log
    log2 = lambda x: log(x) / log(2)  # base-2 logarithm
    e = 0.0
    for p in props:
        if p != 0:  # terms with p = 0 contribute nothing
            e = e - p * log2(p)
    return e


def info_gain(D, A, T=-1, return_ratio=False):
    '''Information gain g(D, A) of feature A on training set D.

    g(D, A) = entropy(D) - entropy(D|A)
    The class label of each tuple in D is at index T; T = -1 (the default)
    means the last element of the tuple is the target.
    With return_ratio=True, returns the C4.5 gain ratio g(D, A) / H_A(D),
    where H_A(D) is the entropy of the distribution of A's values.'''
    if not isinstance(D, (set, list)):
        return None
    if type(A) is not int:
        return None
    C = {}    # count of each class
    DA = {}   # count of each value of feature A
    CDA = {}  # count of each (class, value of A) combination
    for t in D:
        C[t[T]] = C.get(t[T], 0) + 1
        DA[t[A]] = DA.get(t[A], 0) + 1
        CDA[(t[T], t[A])] = CDA.get((t[T], t[A]), 0) + 1

    # class probabilities, i.e. the distribution of the target attribute
    PC = [1.0 * x / len(D) for x in C.values()]
    entropy_D = entropy(PC)

    PCDA = {}  # for each value of A, the conditional probabilities of the classes
    for key, value in CDA.items():
        a = key[1]  # value of feature A
        pca = 1.0 * value / DA[a]
        PCDA.setdefault(a, []).append(pca)

    condition_entropy = 0.0
    for a, v in DA.items():
        p = 1.0 * v / len(D)
        e = entropy(PCDA[a])
        condition_entropy += e * p

    gain = entropy_D - condition_entropy
    if return_ratio:
        # C4.5 gain ratio: normalize by the split information H_A(D)
        split_info = entropy([1.0 * v / len(D) for v in DA.values()])
        return gain / split_info if split_info > 0 else 0.0
    return gain

def get_result(D, T=-1):
    '''Return the most frequent value of the target attribute T in dataset D.'''
    if not isinstance(D, (set, list)):
        return None
    if type(T) is not int:
        return None
    count = {}
    for t in D:
        count[t[T]] = count.get(t[T], 0) + 1
    max_count = 0
    result = None
    for key, value in count.items():
        if value > max_count:
            max_count = value
            result = key
    return result


def devide_set(D, A):
    '''Split dataset D into subsets by the value of feature A.'''
    if not isinstance(D, (set, list)):
        return None
    if type(A) is not int:
        return None
    subset = {}
    for t in D:
        subset.setdefault(t[A], []).append(t)
    return subset


def build_tree(D, A, threshold=0.0001, T=-1, Tree=None, algo="ID3"):
    '''Build a decision tree from dataset D and feature set A.

    T is the index of the target attribute in each tuple. Both ID3
    (information gain) and C4.5 (gain ratio) are supported via algo.'''
    if Tree is not None and not isinstance(Tree, Node):
        return None
    if not isinstance(D, (set, list)):
        return None
    if type(A) is not set:
        return None

    if Tree is None:
        Tree = Node(None, D)
    subset = devide_set(D, T)
    if len(subset) <= 1:
        # every instance has the same class label: make a leaf
        for key in subset.keys():
            Tree.result = key
        return Tree
    if len(A) <= 0:
        # no features left: use the majority class
        Tree.result = get_result(D, T)
        return Tree
    use_gain_ratio = algo != "ID3"

    max_gain = 0
    attr_id = None
    for a in A:
        gain = info_gain(D, a, T, return_ratio=use_gain_ratio)
        if gain > max_gain:
            max_gain = gain
            attr_id = a  # feature with the largest information gain so far
    if attr_id is None or max_gain < threshold:
        Tree.result = get_result(D, T)
        return Tree
    Tree.attr = attr_id
    subD = devide_set(D, attr_id)
    Tree.dataset = None  # internal nodes do not keep their training data
    # Remove the used feature for this path only; a new set keeps sibling
    # branches (and the caller's set) from losing it.
    remaining = A - {attr_id}
    for key, sub in subD.items():
        tree = Node(Tree, sub)
        Tree.childs[key] = tree
        build_tree(sub, remaining, threshold, T, tree, algo)
    return Tree


def print_brance(brance, target):
    '''Print one root-to-leaf path as "attr = value ∧ ... target = label".'''
    odd = 0
    for e in brance:
        print(e, '=' if odd == 0 else '∧', end=' ')
        odd = 1 - odd
    print("target =", target)


def print_tree(Tree, stack=None):
    '''Print every root-to-leaf path (rule) of the tree.'''
    if stack is None:
        stack = []
    if Tree is None:
        return
    if Tree.result is not None:
        print_brance(stack, Tree.result)
        return
    stack.append(Tree.attr)
    for key, value in Tree.childs.items():
        stack.append(key)
        print_tree(value, stack)
        stack.pop()
    stack.pop()


def classify(Tree, instance):
    '''Walk from the root to a leaf following the instance's attribute values.'''
    if Tree is None:
        return None
    if Tree.result is not None:
        return Tree.result
    if instance[Tree.attr] in Tree.childs:
        return classify(Tree.childs[instance[Tree.attr]], instance)
    return None  # attribute value never seen during training

dataset = [
   ("青年", "", "", "一般", "")
   ,("青年", "", "", "", "")
   ,("青年", "", "", "", "")
   ,("青年", "", "", "一般", "")
   ,("青年", "", "", "一般", "")
   ,("中年", "", "", "一般", "")
   ,("中年", "", "", "", "")
   ,("老年", "", "", "非常好", "")
   ,("老年", "", "", "一般", "")
   ,("老年", "", "", "一般", "")
   ,("老年", "", "", "一般", "")
   ,("老年", "", "", "", "")
   ,("老年", "", "", "一般", "")
   ,("老年", "", "", "一般", "")
   ,("老年", "", "", "一般", "")
]

s = set(range(0, len(dataset[0]) - 1))  # every column except the last one (the target)
s = set([0, 1, 3, 4])  # overrides the line above; note that index 4 is the target column itself (T = -1)
T = build_tree(dataset, s)
print_tree(T)
print(classify(T, ("老年", "", "", "一般", "")))
print(classify(T, ("老年", "", "", "一般", "")))
print(classify(T, ("老年", "", "", "", "")))
print(classify(T, ("青年", "", "", "", "")))
print(classify(T, ("中年", "", "", "", "")))

 

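The training set of this Python program is not the example above. Its values are in Chinese: 青年/中年/老年 = youth/middle-aged/senior, 一般/好/非常好 = fair/good/excellent, and 是/否 = yes/no. Each line of the output is one root-to-leaf path written as "attribute index = value", and the last five lines are the results of the five classify calls: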

0 = 青年 ∧ 1 = 否 ∧ target = 否
0 = 青年 ∧ 1 = 是 ∧ target = 是
0 = 中年 ∧ target = 否
0 = 老年 ∧ 3 = 好 ∧ target = 是
0 = 老年 ∧ 3 = 非常好 ∧ target = 是
0 = 老年 ∧ 3 = 一般 ∧ 4 = 否 ∧ target = 否
0 = 老年 ∧ 3 = 一般 ∧ 4 = 是 ∧ target = 是
否
是
是
是
否

  

Java program. It uses the example dataset given above; the code and its output are shown below:

import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

public class ID3Tree {
    private List<String[]> datas;
    private List<Integer> attributes;
    private double threshold = 0.0001;
    private int targetIndex = 1;
    private Node tree;
    private Map<Integer, String> attributeMap;

    protected ID3Tree() {
        super();
    }

    public ID3Tree(List<String[]> datas, List<Integer> attributes, Map<Integer, String> attributeMap, int targetIndex) {
        this(datas, attributes, attributeMap, 0.0001, targetIndex, null);
    }

    public ID3Tree(List<String[]> datas, List<Integer> attributes, Map<Integer, String> attributeMap, double threshold, int targetIndex, Node tree) {
        super();
        this.datas = datas;
        this.attributes = attributes;
        this.attributeMap = attributeMap;
        this.threshold = threshold;
        this.targetIndex = targetIndex;
        this.tree = tree;
    }

    /**
     * Tree node.
     *
     * @author Gerry.Liu
     */
    class Node {
        private List<String[]> dataset; // training instances that fall into this node
        private String result; // class label (set only on leaf nodes)
        private int attr; // index of the attribute this node splits on
        private Node parent; // parent node
        private Map<String, List<Node>> childs; // child nodes, keyed by the value of attr

        public Node(List<String[]> datas, Node parent) {
            this.dataset = datas;
            this.parent = parent;
            this.childs = new HashMap<>();
        }
    }

    /**
     * A (feature value, class value) pair, used as a map key when counting combinations.
     */
    class KeyValue {
        private String first;
        private String second;

        public KeyValue(String first, String second) {
            super();
            this.first = first;
            this.second = second;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + getOuterType().hashCode();
            result = prime * result + ((first == null) ? 0 : first.hashCode());
            result = prime * result + ((second == null) ? 0 : second.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (obj == null)
                return false;
            if (getClass() != obj.getClass())
                return false;
            KeyValue other = (KeyValue) obj;
            if (!getOuterType().equals(other.getOuterType()))
                return false;
            if (first == null) {
                if (other.first != null)
                    return false;
            } else if (!first.equals(other.first))
                return false;
            if (second == null) {
                if (other.second != null)
                    return false;
            } else if (!second.equals(other.second))
                return false;
            return true;
        }

        private ID3Tree getOuterType() {
            return ID3Tree.this;
        }
    }

    /**
     * Computes the entropy of a probability distribution:<br/>
     * entropy(p1,p2....pn) = -p1*log2(p1) -p2*log2(p2)-.....-pn*log2(pn)
     *
     * @param props probabilities p1..pn
     * @return the entropy in bits; 0 for a null or empty list
     */
    private double entropy(List<Double> props) {
        if (props == null || props.isEmpty()) {
            return 0;
        } else {
            double result = 0;
            for (double p : props) {
                if (p > 0) {
                    result = result - p * Math.log(p) / Math.log(2);
                }
            }
            return result;
        }
    }

    /**
     * Converts counts into probabilities.
     *
     * @param totalRecords total number of records
     * @param counts       count of each value
     * @return the probability of each value, or null if there is no data
     */
    private List<Double> calcProbability(int totalRecords, Collection<Integer> counts) {
        if (totalRecords == 0 || counts == null || counts.isEmpty()) {
            return null;
        }

        List<Double> result = new ArrayList<>();
        for (int count : counts) {
            result.add(1.0 * count / totalRecords);
        }
        return result;
    }

    /**
     * Computes the information gain Gain(datas, attribute)<br/>
     * of the feature attribute A on the training set datas (D):<br/>
     * g(D,A) = entropy(D) - entropy(D|A)<br/>
     *
     * @param datas
     *            training set
     * @param attributeIndex
     *            index of the feature attribute
     * @param targetAttributeIndex
     *            index of the target attribute
     * @return the information gain
     */
    private double infoGain(List<String[]> datas, int attributeIndex, int targetAttributeIndex) {
        if (datas == null || datas.isEmpty()) {
            return 0;
        }

        Map<String, Integer> targetAttributeCountMap = new HashMap<String, Integer>(); // count of each class (target value)
        Map<String, Integer> featureAttributesCountMap = new HashMap<>(); // count of each feature value
        Map<KeyValue, Integer> tfAttributeCountMap = new HashMap<>(); // count of each (feature value, class) combination

        for (String[] arrs : datas) {
            String tv = arrs[targetAttributeIndex];
            String fv = arrs[attributeIndex];
            if (targetAttributeCountMap.containsKey(tv)) {
                targetAttributeCountMap.put(tv, targetAttributeCountMap.get(tv) + 1);
            } else {
                targetAttributeCountMap.put(tv, 1);
            }
            if (featureAttributesCountMap.containsKey(fv)) {
                featureAttributesCountMap.put(fv, featureAttributesCountMap.get(fv) + 1);
            } else {
                featureAttributesCountMap.put(fv, 1);
            }
            KeyValue key = new KeyValue(fv, tv);
            if (tfAttributeCountMap.containsKey(key)) {
                tfAttributeCountMap.put(key, tfAttributeCountMap.get(key) + 1);
            } else {
                tfAttributeCountMap.put(key, 1);
            }
        }

        int totalDataSize = datas.size();
        // class probabilities
        List<Double> probabilitys = calcProbability(totalDataSize, targetAttributeCountMap.values());
        // entropy of the target attribute
        double entropyDatas = this.entropy(probabilitys);

        // conditional probabilities
        // step 1: probability of each class given each value of the feature attribute
        Map<String, List<Double>> pcda = new HashMap<>();
        for (Map.Entry<KeyValue, Integer> entry : tfAttributeCountMap.entrySet()) {
            String key = entry.getKey().first;
            double pca = 1.0 * entry.getValue() / featureAttributesCountMap.get(key);
            if (pcda.containsKey(key)) {
                pcda.get(key).add(pca);
            } else {
                List<Double> list = new ArrayList<Double>();
                list.add(pca);
                pcda.put(key, list);
            }
        }
        // step 2: entropy for each feature value, weighted by its frequency (the conditional entropy)
        double conditionEntropy = 0.0;
        for (Map.Entry<String, Integer> entry : featureAttributesCountMap.entrySet()) {
            double p = 1.0 * entry.getValue() / totalDataSize;
            double e = this.entropy(pcda.get(entry.getKey()));
            conditionEntropy += e * p;
        }
        return entropyDatas - conditionEntropy;
    }

    /**
     * Returns the most frequent value of the target attribute in the dataset (the majority class).
     *
     * @param datas                dataset
     * @param targetAttributeIndex index of the target attribute
     * @return the majority class, or null for an empty dataset
     */
    private String getResult(List<String[]> datas, int targetAttributeIndex) {
        if (datas == null || datas.isEmpty()) {
            return null;
        } else {
            String result = "";
            Map<String, Integer> countMap = new HashMap<>();
            for (String[] arr : datas) {
                String key = arr[targetAttributeIndex];
                if (countMap.containsKey(key)) {
                    countMap.put(key, countMap.get(key) + 1);
                } else {
                    countMap.put(key, 1);
                }
            }

            int maxCount = -1;
            for (Map.Entry<String, Integer> entry : countMap.entrySet()) {
                if (entry.getValue() > maxCount) {
                    maxCount = entry.getValue();
                    result = entry.getKey();
                }
            }
            return result;
        }
    }

    /**
     * Splits dataset D into subsets by the value of a feature attribute.
     *
     * @param datas
     *            dataset
     * @param attributeIndex
     *            index of the feature attribute
     * @return a map from attribute value to the records with that value (never null)
     */
    private Map<String, List<String[]>> devideDatas(List<String[]> datas, int attributeIndex) {
        Map<String, List<String[]>> subdatas = new HashMap<>();
        if (datas != null && !datas.isEmpty()) {
            for (String[] arr : datas) {
                String key = arr[attributeIndex];
                if (subdatas.containsKey(key)) {
                    subdatas.get(key).add(arr);
                } else {
                    List<String[]> list = new ArrayList<>();
                    list.add(arr);
                    subdatas.put(key, list);
                }
            }
        }
        return subdatas;
    }

    /**
     * Prints the decision tree, one root-to-leaf path (rule) per line.
     *
     * @param tree  current node
     * @param stock stack of the attribute names and values on the current path
     */
    private void printTree(Node tree, Deque<Object> stock) {
        if (tree == null) {
            return;
        }

        if (tree.result != null) {
            this.printBrance(stock, tree.result);
        } else {
            stock.push(this.attributeMap.get(tree.attr));
            for (Map.Entry<String, List<Node>> entry : tree.childs.entrySet()) {
                stock.push(entry.getKey());
                for (Node node : entry.getValue()) {
                    this.printTree(node, stock);
                }
                stock.pop();
            }
            stock.pop();
        }
    }

    /**
     * Prints one rule: the path held in stock, followed by the class at the leaf.
     *
     * @param stock  attribute names and values on the path, most recent first
     * @param target class label at the leaf
     */
    private void printBrance(Deque<Object> stock, String target) {
        StringBuilder sb = new StringBuilder();
        int odd = 0;
        // The deque iterates from the most recently pushed element, so each
        // element is inserted at the front to restore root-to-leaf order.
        for (Object e : stock) {
            sb.insert(0, odd == 0 ? "^" : "=").insert(0, e);
            odd = 1 - odd;
        }
        sb.append("target=").append(target);
        System.out.println(sb.toString());
    }

    /**
     * Builds a decision tree recursively.
     *
     * @param datas       training records that reach this node
     * @param attributes  indexes of the attributes still available on this path
     * @param threshold   minimum information gain required to split
     * @param targetIndex index of the target attribute
     * @param tree        node to fill in (created if null)
     * @return the filled-in node
     */
    private Node buildTree(List<String[]> datas, List<Integer> attributes, double threshold, int targetIndex, Node tree) {
        if (tree == null) {
            tree = new Node(datas, null);
        }
        // split by the target attribute; the returned map may be empty but is never null
        Map<String, List<String[]>> subDatas = this.devideDatas(datas, targetIndex);
        if (subDatas.size() <= 1) {
            // at most one key: every instance has the same class, so this is a leaf
            for (String key : subDatas.keySet()) {
                tree.result = key;
            }
        } else if (attributes == null || attributes.size() < 1) {
            // no attributes left: use the majority class
            tree.result = this.getResult(datas, targetIndex);
        } else {
            double maxGain = 0;
            int attr = 0;

            for (int attribute : attributes) {
                double gain = this.infoGain(datas, attribute, targetIndex);
                if (gain > maxGain) {
                    maxGain = gain;
                    attr = attribute; // index of the attribute with the largest gain so far
                }
            }

            if (maxGain < threshold) {
                // gain too small to be worth a split: use the majority class
                tree.result = this.getResult(datas, targetIndex);
            } else {
                // split on the chosen attribute and recurse into each subset
                tree.attr = attr;
                subDatas = this.devideDatas(datas, attr);
                tree.dataset = null;
                // copy the list so the attribute is removed only along this path,
                // not from sibling branches that would otherwise share the same list
                List<Integer> remaining = new ArrayList<>(attributes);
                remaining.remove(Integer.valueOf(attr));
                for (String key : subDatas.keySet()) {
                    Node childTree = new Node(subDatas.get(key), tree);
                    if (tree.childs.containsKey(key)) {
                        tree.childs.get(key).add(childTree);
                    } else {
                        List<Node> childs = new ArrayList<>();
                        childs.add(childTree);
                        tree.childs.put(key, childs);
                    }
                    this.buildTree(subDatas.get(key), remaining, threshold, targetIndex, childTree);
                }
            }
        }
        return tree;
    }

    /**
     * Classifies an instance by following the decision rules from the given node.
     *
     * @param tree     current node
     * @param instance attribute values of the instance, indexed like the training rows
     * @return the predicted class, or null if a value was never seen during training
     */
    private String classify(Node tree, String[] instance) {
        if (tree == null) {
            return null;
        }
        if (tree.result != null) {
            return tree.result;
        }
        List<Node> nodes = tree.childs.get(instance[tree.attr]);
        if (nodes != null && !nodes.isEmpty()) {
            return this.classify(nodes.get(0), instance);
        }
        return null;
    }

    /**
     * Builds the decision tree.
     */
    public void buildTree() {
        this.tree = new Node(this.datas, null);
        this.buildTree(datas, attributes, threshold, targetIndex, tree);
    }

    /**
     * Prints the generated rules.
     */
    public void printTree() {
        this.printTree(this.tree, new LinkedList<>());
    }

    /**
     * Returns the predicted class for an instance.
     *
     * @param instance attribute values of the instance
     * @return the predicted class, or null if no rule applies
     */
    public String classify(String[] instance) {
        return this.classify(this.tree, instance);
    }

    public static void main(String[] args) {
        List<String[]> dataset = new ArrayList<>();
        dataset.add(new String[] { "1", "youth", "high", "no", "fair", "no" });
        dataset.add(new String[] { "2", "youth", "high", "no", "excellent", "no" });
        dataset.add(new String[] { "3", "middle_aged", "high", "no", "fair", "yes" });
        dataset.add(new String[] { "4", "senior", "medium", "no", "fair", "yes" });
        dataset.add(new String[] { "5", "senior", "low", "yes", "fair", "yes" });
        dataset.add(new String[] { "6", "senior", "low", "yes", "excellent", "no" });
        dataset.add(new String[] { "7", "middle_aged", "low", "yes", "excellent", "yes" });
        dataset.add(new String[] { "8", "youth", "medium", "no", "fair", "no" });
        dataset.add(new String[] { "9", "youth", "low", "yes", "fair", "yes" });
        dataset.add(new String[] { "10", "senior", "medium", "yes", "fair", "yes" });
        dataset.add(new String[] { "11", "youth", "medium", "yes", "excellent", "yes" });
        dataset.add(new String[] { "12", "middle_aged", "medium", "no", "excellent", "yes" });
        dataset.add(new String[] { "13", "middle_aged", "high", "yes", "fair", "yes" });
        dataset.add(new String[] { "14", "senior", "medium", "no", "excellent", "no" });

        List<Integer> attributes = new ArrayList<>();
        attributes.add(4);
        attributes.add(1);
        attributes.add(2);
        attributes.add(3);

        Map<Integer, String> attributeMap = new HashMap<>();
        attributeMap.put(1, "Age");
        attributeMap.put(2, "Income");
        attributeMap.put(3, "Student");
        attributeMap.put(4, "Credit_rating");

        int targetIndex = 5;

        String[] instance = new String[] { "15", "youth", "medium", "yes", "fair" };

        ID3Tree tree = new ID3Tree(dataset, attributes, attributeMap, targetIndex);
        System.out.println("start build the tree");
        tree.buildTree();
        System.out.println("completed build the tree, start print the tree");
        tree.printTree();
        System.out.println("start classify.....");
        String result = tree.classify(instance);
        System.out.println(result);
    }
}

Output of the Java program:

start build the tree
completed build the tree, start print the tree
Age=youth^Student=yes^target=yes
Age=youth^Student=no^target=no
Age=middle_aged^target=yes
Age=senior^Credit_rating=excellent^target=no
Age=senior^Credit_rating=fair^target=yes
start classify.....
yes

5. Limitations of ID3

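ID3 is relatively slow, and it can only work on data that has been loaded into memory. As a result, the datasets it can handle are smaller than those of some other algorithms.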

 

posted @ 2015-01-28 18:36  liuming_1992