ID3 Algorithm
I. A Brief Introduction to ID3
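ID3 was developed by Ross Quinlan at the University of Sydney and described in his paper "Induction of Decision Trees" in the journal Machine Learning (vol. 1, no. 1, 1986). It builds on the Concept Learning System (CLS) algorithm. ID3 is a decision-tree algorithm. A decision tree consists of decision nodes, branches and leaves. The topmost node is the root; each branch leads either to a new decision node or to a leaf. Each decision node represents a question or decision, usually corresponding to an attribute of the object being classified, and each leaf represents a possible classification result. When the tree is traversed from top to bottom, a test is performed at every node, and each outcome of the test leads to a different branch until a leaf is reached. This is how a decision tree classifies an object: it uses the values of several variables to determine the class the object belongs to.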
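A decision tree can be represented and traversed in a few lines of Python. The sketch below is only an illustration: the nested-tuple representation and the `classify` function are assumptions of this example rather than part of ID3, and the hard-coded tree is the one obtained in the worked example later in this article.

```python
# Hypothetical representation: an internal node is a tuple
# (attribute_name, {attribute_value: subtree}); a leaf is a class label string.
TREE = ("age", {
    "youth": ("student", {"yes": "yes", "no": "no"}),
    "middle_aged": "yes",
    "senior": ("credit_rating", {"excellent": "no", "fair": "yes"}),
})


def classify(node, instance):
    """Walk from the root to a leaf, testing one attribute per decision node."""
    while not isinstance(node, str):          # stop once we reach a leaf label
        attribute, branches = node
        node = branches[instance[attribute]]  # follow the branch for this value
    return node


print(classify(TREE, {"age": "youth", "income": "medium",
                      "student": "yes", "credit_rating": "fair"}))  # prints yes
```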
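II. Foundations of ID3: Information Theory
ID3 is built on information theory: it uses entropy and information gain as its measures for classifying data. Some basic concepts from information theory are listed below.
Definition 1: If there are n equally probable messages, each has probability p = 1/n, and the information conveyed by one message is -log2(1/n) = log2(n).
Definition 2: If there are n messages with probability distribution P = (p1, p2, ..., pn), the information conveyed by this distribution is called the entropy of P: I(P) = -p1*log2(p1) - p2*log2(p2) - ... - pn*log2(pn).
Definition 3: If a set of records T is partitioned by the values of the class attribute into disjoint classes C1, C2, ..., Ck, then the information needed to identify the class of an element of T is Info(T) = I(P), where P is the class distribution P = (|C1|/|T|, |C2|/|T|, ..., |Ck|/|T|).
Definition 4: If T is first partitioned by the values of a non-class attribute X into subsets T1, T2, ..., Tn, then the information needed to identify the class of an element of T is the weighted average of Info(Ti): Info(X,T) = Σ_{i=1..n} (|Ti|/|T|) * Info(Ti).
Definition 5: The information gain is the difference between two quantities: the information needed to identify the class of an element of T, and the information still needed once the value of attribute X is known: Gain(X,T) = Info(T) - Info(X,T).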
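Definitions 2 to 5 translate directly into code. This is a minimal sketch; the function names `I`, `info`, `info_x` and `gain` are hypothetical, and the attribute X and the class are passed as two parallel lists of values.

```python
from collections import Counter
from math import log2


def I(P):
    """Definition 2: entropy of a probability distribution P = (p1, ..., pn)."""
    return -sum(p * log2(p) for p in P if p > 0)  # 0*log2(0) is taken as 0


def info(labels):
    """Definition 3: Info(T), computed from the class proportions |Ci|/|T|."""
    n = len(labels)
    return I([c / n for c in Counter(labels).values()])


def info_x(values, labels):
    """Definition 4: Info(X,T), the weighted average of Info(Ti) over the values of X."""
    n = len(labels)
    groups = {}
    for v, y in zip(values, labels):
        groups.setdefault(v, []).append(y)  # Ti = class labels of records with X == v
    return sum(len(ti) / n * info(ti) for ti in groups.values())


def gain(values, labels):
    """Definition 5: Gain(X,T) = Info(T) - Info(X,T)."""
    return info(labels) - info_x(values, labels)


print(I([0.25] * 4))  # Definition 1 with n = 4: prints 2.0
```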
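ID3 computes the information gain of every attribute and chooses the attribute with the highest gain as the test attribute for the given set. A node is created for the chosen attribute and labeled with it, one branch is created for each value of that attribute, and the information-gain computation is then applied recursively to each branch.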
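The recursive procedure just described can be sketched as follows. It assumes each record is a Python dict and every attribute is categorical; `entropy`, `split`, `gain` and `id3` are hypothetical names, and the gain and purity thresholds discussed in the example below are left out.

```python
from collections import Counter
from math import log2


def entropy(rows, target):
    """Info(T): entropy of the target attribute over rows."""
    n = len(rows)
    return -sum(c / n * log2(c / n) for c in Counter(r[target] for r in rows).values())


def split(rows, attr):
    """Partition rows (dicts) by their value of attr."""
    parts = {}
    for r in rows:
        parts.setdefault(r[attr], []).append(r)
    return parts


def gain(rows, attr, target):
    """Gain(attr, T) = Info(T) - Info(attr, T)."""
    n = len(rows)
    cond = sum(len(p) / n * entropy(p, target) for p in split(rows, attr).values())
    return entropy(rows, target) - cond


def id3(rows, attrs, target):
    """Return a class label (leaf) or a node (attr, {value: subtree})."""
    labels = Counter(r[target] for r in rows)
    if len(labels) == 1 or not attrs:  # pure subset or no attributes left: majority label
        return labels.most_common(1)[0][0]
    best = max(attrs, key=lambda a: gain(rows, a, target))  # highest information gain
    rest = [a for a in attrs if a != best]  # fresh list, so sibling branches are unaffected
    return best, {v: id3(part, rest, target) for v, part in split(rows, best).items()}
```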
III. A Worked Example of ID3
Here is an example of applying ID3:
RID | age | income | student | credit_rating | buy_computer |
1 | youth | high | no | fair | no |
2 | youth | high | no | excellent | no |
3 | middle_aged | high | no | fair | yes |
4 | senior | medium | no | fair | yes |
5 | senior | low | yes | fair | yes |
6 | senior | low | yes | excellent | no |
7 | middle_aged | low | yes | excellent | yes |
8 | youth | medium | no | fair | no |
9 | youth | low | yes | fair | yes |
10 | senior | medium | yes | fair | yes |
11 | youth | medium | yes | excellent | yes |
12 | middle_aged | medium | no | excellent | yes |
13 | middle_aged | high | yes | fair | yes |
14 | senior | medium | no | excellent | no |
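There are 14 records in total. The reference attributes are age (youth[5], middle_aged[4], senior[5]), income (high[4], medium[6], low[4]), student (no[7], yes[7]) and credit_rating (fair[8], excellent[6]). The target attribute is buy_computer (no[5], yes[9]), and the goal is to predict the value of buy_computer from age, income, student and credit_rating. Let D be the initial dataset and A the list of reference attributes; the computation proceeds as follows.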
Step 1: Compute the entropy of the target attribute over D: Info(buy_computer) = -(5/14)*log2(5/14) - (9/14)*log2(9/14) = 0.94
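Step 1 can be checked numerically. The helper `entropy`, which takes raw class counts instead of probabilities, is a hypothetical name used only in this sketch.

```python
from math import log2


def entropy(counts):
    """Entropy (in bits) of a class distribution given as raw counts."""
    total = sum(counts)
    return -sum(c / total * log2(c / total) for c in counts if c > 0)


# buy_computer: no[5], yes[9]
info_d = entropy([5, 9])
print(round(info_d, 3))  # prints 0.94
```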
Step 2: For each attribute in A, compute the entropy of buy_computer given that the value of that attribute is known, i.e. the conditional entropy.
For age: youth (no[3], yes[2]), middle_aged (no[0], yes[4]), senior (no[2], yes[3]). First compute the entropy within each of youth, middle_aged and senior:
Info_age(buy_computer|youth) = -(3/5)*log2(3/5) - (2/5)*log2(2/5) = 0.971
Info_age(buy_computer|middle_aged) = -(4/4)*log2(4/4) - (0/4)*log2(0/4) = 0 (using the convention 0*log2(0) = 0)
Info_age(buy_computer|senior) = -(2/5)*log2(2/5) - (3/5)*log2(3/5) = 0.971
Then Info_age(buy_computer) = 5/14*0.971 + 4/14*0 + 5/14*0.971 = 0.694
Similarly: Info_income(buy_computer) = 0.911; Info_student(buy_computer) = 0.789; Info_credit_rating(buy_computer) = 0.892.
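Step 2 for all four attributes can be reproduced from the (no, yes) counts of buy_computer for each attribute value, read off the table above. This sketch uses the hypothetical helpers `entropy` and `conditional_entropy`; computed without intermediate rounding, the results agree with the values above to within 0.001 (student comes out at about 0.788).

```python
from math import log2


def entropy(counts):
    """Entropy (in bits) of a class distribution given as raw counts."""
    total = sum(counts)
    return -sum(c / total * log2(c / total) for c in counts if c > 0)


def conditional_entropy(groups):
    """Weighted average entropy; groups maps each attribute value to (no, yes) counts."""
    total = sum(sum(g) for g in groups.values())
    return sum(sum(g) / total * entropy(g) for g in groups.values())


# (no, yes) counts of buy_computer for each attribute value
GROUPS = {
    "age": {"youth": (3, 2), "middle_aged": (0, 4), "senior": (2, 3)},
    "income": {"high": (2, 2), "medium": (2, 4), "low": (1, 3)},
    "student": {"no": (4, 3), "yes": (1, 6)},
    "credit_rating": {"fair": (2, 6), "excellent": (3, 3)},
}

cond = {attr: conditional_entropy(groups) for attr, groups in GROUPS.items()}
for attr, value in cond.items():
    print(attr, round(value, 3))
```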
Step 3: Compute the information gain. The larger the gain, the more of the target attribute's entropy is removed once that attribute is known, and the higher that attribute should be placed in the decision tree. The results are:
Gain(age, buy_computer) = Info(buy_computer) - Info_age(buy_computer) = 0.94 - 0.694 = 0.246
Gain(income, buy_computer) = Info(buy_computer) - Info_income(buy_computer) = 0.94 - 0.911 = 0.029
Gain(student, buy_computer) = Info(buy_computer) - Info_student(buy_computer) = 0.94 - 0.789 = 0.151
Gain(credit_rating, buy_computer) = Info(buy_computer) - Info_credit_rating(buy_computer) = 0.94 - 0.892 = 0.048
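Steps 1 to 3 can also be run end to end on the raw records. In this sketch the table is re-typed as Python tuples, and `info` and `gain` are hypothetical helper names. Because nothing is rounded until the end, the printed gains can differ from the hand-computed ones in the third decimal place (about 0.247 for age and 0.152 for student), but the attribute chosen is the same.

```python
from collections import Counter
from math import log2

COLUMNS = ["age", "income", "student", "credit_rating", "buy_computer"]
# The 14 training records from the table above (RID 1..14, in order)
ROWS = [
    ("youth", "high", "no", "fair", "no"),
    ("youth", "high", "no", "excellent", "no"),
    ("middle_aged", "high", "no", "fair", "yes"),
    ("senior", "medium", "no", "fair", "yes"),
    ("senior", "low", "yes", "fair", "yes"),
    ("senior", "low", "yes", "excellent", "no"),
    ("middle_aged", "low", "yes", "excellent", "yes"),
    ("youth", "medium", "no", "fair", "no"),
    ("youth", "low", "yes", "fair", "yes"),
    ("senior", "medium", "yes", "fair", "yes"),
    ("youth", "medium", "yes", "excellent", "yes"),
    ("middle_aged", "medium", "no", "excellent", "yes"),
    ("middle_aged", "high", "yes", "fair", "yes"),
    ("senior", "medium", "no", "excellent", "no"),
]
DATA = [dict(zip(COLUMNS, r)) for r in ROWS]


def info(rows, target="buy_computer"):
    """Entropy of the target attribute over rows."""
    n = len(rows)
    return -sum(c / n * log2(c / n) for c in Counter(r[target] for r in rows).values())


def gain(rows, attr, target="buy_computer"):
    """Gain(attr) = Info(rows) - sum(|Ti|/|T| * Info(Ti))."""
    parts = {}
    for r in rows:
        parts.setdefault(r[attr], []).append(r)
    cond = sum(len(p) / len(rows) * info(p, target) for p in parts.values())
    return info(rows, target) - cond


gains = {a: gain(DATA, a) for a in COLUMNS[:-1]}
for a, g in gains.items():
    print(f"Gain({a}) = {g:.3f}")
print("split on:", max(gains, key=gains.get))  # prints: split on: age
```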
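Step 4: Choose the attribute with the largest information gain, here age, as the current node, and partition the initial dataset D by the values of age:
1. When age is youth, the subset D1 is: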
RID | income | student | credit_rating | buy_computer |
1 | high | no | fair | no |
2 | high | no | excellent | no |
8 | medium | no | fair | no |
9 | low | yes | fair | yes |
11 | medium | yes | excellent | yes |
2. When age is middle_aged, the subset D2 is:
RID | income | student | credit_rating | buy_computer |
3 | high | no | fair | yes |
7 | low | yes | excellent | yes |
12 | medium | no | excellent | yes |
13 | high | yes | fair | yes |
3. When age is senior, the subset D3 is:
RID | income | student | credit_rating | buy_computer |
4 | medium | no | fair | yes |
5 | low | yes | fair | yes |
6 | low | yes | excellent | no |
10 | medium | yes | fair | yes |
14 | medium | no | excellent | no |
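Step 4's partition is a simple group-by. The sketch below keeps only the RID, age and buy_computer columns from the table; `RECORDS` and `partition` are hypothetical names. It reports youth → [1, 2, 8, 9, 11], middle_aged → [3, 7, 12, 13] and senior → [4, 5, 6, 10, 14], and flags middle_aged as the only pure subset.

```python
# RID -> (age, buy_computer), copied from the table above
RECORDS = {
    1: ("youth", "no"), 2: ("youth", "no"), 3: ("middle_aged", "yes"),
    4: ("senior", "yes"), 5: ("senior", "yes"), 6: ("senior", "no"),
    7: ("middle_aged", "yes"), 8: ("youth", "no"), 9: ("youth", "yes"),
    10: ("senior", "yes"), 11: ("youth", "yes"), 12: ("middle_aged", "yes"),
    13: ("middle_aged", "yes"), 14: ("senior", "no"),
}


def partition(records, index):
    """Group record IDs by the value stored at position index."""
    subsets = {}
    for rid, values in records.items():
        subsets.setdefault(values[index], []).append(rid)
    return subsets


subsets = partition(RECORDS, 0)  # split on age
for age, rids in subsets.items():
    labels = {RECORDS[r][1] for r in rids}
    print(age, rids, "pure" if len(labels) == 1 else "needs further splitting")
```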
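Step 5: Remove the chosen attribute (age) from the attribute list A, and repeat from Step 1 on each subset Di produced in Step 4, using the reduced list A. The iteration stops when:
- all records in a subset have the same target value, as with age = middle_aged here;
- the proportion of one target value in a subset reaches a manually chosen threshold; for example, stop as soon as one value of buy_computer accounts for 90% or more of the subset.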
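The two stopping conditions of Step 5 can be folded into a single check, since a subset with only one target value has a majority proportion of 1. This is a sketch; `should_stop` and its `purity` parameter (defaulting to the 90% example above) are hypothetical.

```python
from collections import Counter


def should_stop(labels, purity=0.9):
    """Return (stop, majority_label) for the target values of one subset.

    A single remaining target value gives a proportion of 1.0, so the first
    stopping condition is a special case of the purity threshold."""
    label, count = Counter(labels).most_common(1)[0]
    return count / len(labels) >= purity, label


print(should_stop(["yes", "yes", "yes", "yes"]))       # D2: (True, 'yes')
print(should_stop(["no", "no", "no", "yes", "yes"]))   # D1: (False, 'no')
```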
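After several iterations the final decision tree has age at its root: the youth branch is split on student, the senior branch is split on credit_rating, and the middle_aged branch is a leaf labeled yes. Reading off each root-to-leaf path gives the rules:
IF AGE = middle_aged, THEN buy_computer = yes
IF AGE = youth AND STUDENT = yes, THEN buy_computer = yes
IF AGE = youth AND STUDENT = no, THEN buy_computer = no
IF AGE = senior AND CREDIT_RATING = excellent, THEN buy_computer = no
IF AGE = senior AND CREDIT_RATING = fair, THEN buy_computer = yes
So, for the instance ("15", "youth", "medium", "yes", "fair"), the predicted value of buy_computer is "yes".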
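Each rule is one root-to-leaf path of the tree. The sketch below walks a hard-coded copy of the final tree and prints the same five rules; the nested-tuple representation and `tree_to_rules` are assumptions of this example.

```python
# Final tree of the example in a hypothetical nested form:
# an internal node is (attribute, {value: subtree}); a leaf is a class label.
TREE = ("age", {
    "middle_aged": "yes",
    "youth": ("student", {"yes": "yes", "no": "no"}),
    "senior": ("credit_rating", {"excellent": "no", "fair": "yes"}),
})


def tree_to_rules(node, conditions=()):
    """Yield one IF-THEN rule per root-to-leaf path."""
    if isinstance(node, str):  # leaf: emit the conditions collected on the way down
        yield "IF " + " AND ".join(conditions) + f", THEN buy_computer = {node}"
        return
    attribute, branches = node
    for value, child in branches.items():
        yield from tree_to_rules(child, conditions + (f"{attribute} = {value}",))


for rule in tree_to_rules(TREE):
    print(rule)
```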
IV. Implementing ID3
Below are implementations of ID3 in Python and in Java.
Python program:
```python
# -*- coding: utf-8 -*-
from math import log2


class Node:
    '''Represents a decision tree node.'''

    def __init__(self, parent=None, dataset=None):
        self.dataset = dataset  # training instances that fall into this node
        self.result = None      # class label (set on leaves only)
        self.attr = None        # index of the attribute this node splits on
        self.childs = {}        # subtrees: {value of attr: child node}
        self.parent = parent    # parent node


def entropy(props):
    '''Entropy of a probability distribution: -sum(p * log2(p)).'''
    if not isinstance(props, (tuple, list)):
        return None
    e = 0.0
    for p in props:
        if p != 0:
            e = e - p * log2(p)
    return e


def info_gain(D, A, T=-1, return_ratio=False):
    '''Information gain g(D, A) of feature A on training set D.

    g(D, A) = entropy(D) - entropy(D | A)
    The target attribute is at index T of each tuple (-1 = last element).
    If return_ratio is True, the C4.5 gain ratio g(D, A) / H_A(D) is returned,
    where H_A(D) is the entropy of the distribution of A's values.'''
    if not isinstance(D, (set, list)):
        return None
    if not isinstance(A, int):
        return None
    C = {}    # class counts
    DA = {}   # counts of each value of feature A
    CDA = {}  # counts of each (class, value of A) combination
    for t in D:
        C[t[T]] = C.get(t[T], 0) + 1
        DA[t[A]] = DA.get(t[A], 0) + 1
        CDA[(t[T], t[A])] = CDA.get((t[T], t[A]), 0) + 1

    # class probabilities; map returns an iterator, so convert it to a tuple
    PC = map(lambda x: x / len(D), C.values())
    entropy_D = entropy(tuple(PC))

    PCDA = {}  # class probabilities conditioned on each value of A
    for (_, a), value in CDA.items():
        PCDA.setdefault(a, []).append(value / DA[a])

    condition_entropy = 0.0
    for a, v in DA.items():
        condition_entropy += entropy(PCDA[a]) * v / len(D)

    gain = entropy_D - condition_entropy
    if return_ratio:
        split_info = entropy([v / len(D) for v in DA.values()])
        return gain / split_info if split_info > 0 else 0.0
    return gain


def get_result(D, T=-1):
    '''Return the most frequent value of target attribute T in D.'''
    if not isinstance(D, (set, list)):
        return None
    if not isinstance(T, int):
        return None
    count = {}
    for t in D:
        count[t[T]] = count.get(t[T], 0) + 1
    max_count = 0
    result = None
    for key, value in count.items():
        if value > max_count:
            max_count = value
            result = key
    return result


def devide_set(D, A):
    '''Split dataset D into subsets by the value of feature A.'''
    if not isinstance(D, (set, list)):
        return None
    if not isinstance(A, int):
        return None
    subset = {}
    for t in D:
        subset.setdefault(t[A], []).append(t)
    return subset


def build_tree(D, A, threshold=0.0001, T=-1, Tree=None, algo="ID3"):
    '''Build a decision tree from dataset D and feature set A.

    T is the index of the target attribute in each tuple. Supports "ID3" and "C4.5".'''
    if Tree is not None and not isinstance(Tree, Node):
        return None
    if not isinstance(D, (set, list)):
        return None
    if not isinstance(A, set):
        return None

    if Tree is None:
        Tree = Node(None, D)
    subset = devide_set(D, T)
    if len(subset) <= 1:
        # all instances share one class: make a leaf
        for key in subset.keys():
            Tree.result = key
        return Tree
    if len(A) <= 0:
        Tree.result = get_result(D, T)
        return Tree
    use_gain_ratio = algo != "ID3"

    max_gain = 0
    attr_id = None
    for a in A:
        gain = info_gain(D, a, T, return_ratio=use_gain_ratio)
        if gain > max_gain:
            max_gain = gain
            attr_id = a  # feature with the largest gain so far
    if max_gain < threshold:
        Tree.result = get_result(D, T)
        return Tree
    Tree.attr = attr_id
    subD = devide_set(D, attr_id)
    Tree.dataset = None  # the data now lives in the children
    # each branch gets its own copy of the remaining features,
    # so removing attr_id here does not affect sibling branches
    remaining = A - {attr_id}
    for key, sub in subD.items():
        child = Node(Tree, sub)
        Tree.childs[key] = child
        build_tree(sub, remaining, threshold, T, child, algo)
    return Tree


def print_brance(brance, target):
    '''Print one root-to-leaf path as "attr = value ∧ ... target = label".'''
    odd = 0
    for e in brance:
        print(e, '=' if odd == 0 else '∧', end=' ')
        odd = 1 - odd
    print("target =", target)


def print_tree(Tree, stack=None):
    '''Print every root-to-leaf path of the tree.'''
    if stack is None:
        stack = []
    if Tree is None:
        return
    if Tree.result is not None:
        print_brance(stack, Tree.result)
        return
    stack.append(Tree.attr)
    for key, value in Tree.childs.items():
        stack.append(key)
        print_tree(value, stack)
        stack.pop()
    stack.pop()


def classify(Tree, instance):
    '''Follow the branches matching instance down to a leaf.'''
    if Tree is None:
        return None
    if Tree.result is not None:
        return Tree.result
    # an attribute value not seen in training leads to None
    return classify(Tree.childs.get(instance[Tree.attr]), instance)


dataset = [
    ("youth", "no", "no", "fair", "no"),
    ("youth", "no", "no", "good", "no"),
    ("youth", "yes", "no", "good", "yes"),
    ("youth", "yes", "yes", "fair", "yes"),
    ("youth", "no", "no", "fair", "no"),
    ("middle_aged", "no", "no", "fair", "no"),
    ("middle_aged", "no", "no", "good", "no"),
    ("senior", "yes", "no", "excellent", "yes"),
    ("senior", "no", "yes", "fair", "yes"),
    ("senior", "no", "yes", "fair", "yes"),
    ("senior", "no", "yes", "fair", "yes"),
    ("senior", "no", "yes", "good", "yes"),
    ("senior", "yes", "no", "fair", "yes"),
    ("senior", "yes", "no", "fair", "no"),
    ("senior", "no", "no", "fair", "no"),
]

s = set(range(0, len(dataset[0]) - 1))  # feature indices {0, 1, 2, 3}; index 4 is the label
T = build_tree(dataset, s)
print_tree(T)
print(classify(T, ("senior", "yes", "no", "fair", "no")))
print(classify(T, ("senior", "yes", "no", "fair", "yes")))
print(classify(T, ("senior", "yes", "yes", "good", "no")))
print(classify(T, ("youth", "yes", "no", "good", "yes")))
print(classify(T, ("middle_aged", "yes", "no", "good", "no")))
```
The training set of this Python program is not the example above: it has 15 tuples whose columns 0 to 3 are features and whose last column is the class label. Run with Python 3, it prints:
2 = no ∧ 1 = no ∧ target = no
2 = no ∧ 1 = yes ∧ 3 = good ∧ target = yes
2 = no ∧ 1 = yes ∧ 3 = excellent ∧ target = yes
2 = no ∧ 1 = yes ∧ 3 = fair ∧ target = yes
2 = yes ∧ target = yes
yes
yes
yes
yes
yes
Java program. Its dataset is the example given above; the code and its output follow:
```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class ID3Tree {
    private List<String[]> datas;
    private List<Integer> attributes;
    private double threshold = 0.0001;
    private int targetIndex = 1;
    private Node tree;
    private Map<Integer, String> attributeMap;

    protected ID3Tree() {
        super();
    }

    public ID3Tree(List<String[]> datas, List<Integer> attributes, Map<Integer, String> attributeMap, int targetIndex) {
        this(datas, attributes, attributeMap, 0.0001, targetIndex, null);
    }

    public ID3Tree(List<String[]> datas, List<Integer> attributes, Map<Integer, String> attributeMap,
            double threshold, int targetIndex, Node tree) {
        super();
        this.datas = datas;
        this.attributes = attributes;
        this.attributeMap = attributeMap;
        this.threshold = threshold;
        this.targetIndex = targetIndex;
        this.tree = tree;
    }

    /**
     * Tree node.
     *
     * @author Gerry.Liu
     */
    class Node {
        private List<String[]> dataset;         // training instances that fall into this node
        private String result;                  // class label (set on leaves only)
        private int attr;                       // index of the splitting attribute
        private Node parent;                    // parent node
        private Map<String, List<Node>> childs; // child nodes, keyed by attribute value

        public Node(List<String[]> datas, Node parent) {
            this.dataset = datas;
            this.parent = parent;
            this.childs = new HashMap<>();
        }
    }

    /** Composite key (feature value, target value) used for counting. */
    class KeyValue {
        private final String first;
        private final String second;

        public KeyValue(String first, String second) {
            this.first = first;
            this.second = second;
        }

        @Override
        public int hashCode() {
            return Objects.hash(first, second);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof KeyValue)) {
                return false;
            }
            KeyValue other = (KeyValue) obj;
            return Objects.equals(first, other.first) && Objects.equals(second, other.second);
        }
    }

    /**
     * Entropy of a probability distribution:<br/>
     * entropy(p1, p2, ..., pn) = -p1*log2(p1) - p2*log2(p2) - ... - pn*log2(pn)
     */
    private double entropy(List<Double> props) {
        if (props == null || props.isEmpty()) {
            return 0;
        }
        double result = 0;
        for (double p : props) {
            if (p > 0) {
                result -= p * Math.log(p) / Math.log(2);
            }
        }
        return result;
    }

    /** Converts counts into probabilities. */
    private List<Double> calcProbability(int totalRecords, Collection<Integer> counts) {
        if (totalRecords == 0 || counts == null || counts.isEmpty()) {
            return null;
        }
        List<Double> result = new ArrayList<>();
        for (int count : counts) {
            result.add(1.0 * count / totalRecords);
        }
        return result;
    }

    /**
     * Information gain of feature attribute A on training set D:<br/>
     * g(D, A) = entropy(D) - entropy(D|A)
     *
     * @param datas                training set
     * @param attributeIndex       index of the feature attribute
     * @param targetAttributeIndex index of the target attribute
     */
    private double infoGain(List<String[]> datas, int attributeIndex, int targetAttributeIndex) {
        if (datas == null || datas.isEmpty()) {
            return 0;
        }
        Map<String, Integer> targetAttributeCountMap = new HashMap<>();   // class (target) counts
        Map<String, Integer> featureAttributesCountMap = new HashMap<>(); // counts of each feature value
        Map<KeyValue, Integer> tfAttributeCountMap = new HashMap<>();     // counts of (feature, class) pairs
        for (String[] arrs : datas) {
            String tv = arrs[targetAttributeIndex];
            String fv = arrs[attributeIndex];
            targetAttributeCountMap.merge(tv, 1, Integer::sum);
            featureAttributesCountMap.merge(fv, 1, Integer::sum);
            tfAttributeCountMap.merge(new KeyValue(fv, tv), 1, Integer::sum);
        }

        int totalDataSize = datas.size();
        // entropy of the target attribute
        double entropyDatas = this.entropy(calcProbability(totalDataSize, targetAttributeCountMap.values()));

        // step 1: class probabilities conditioned on each feature value
        Map<String, List<Double>> pcda = new HashMap<>();
        for (Map.Entry<KeyValue, Integer> entry : tfAttributeCountMap.entrySet()) {
            String key = entry.getKey().first;
            double pca = 1.0 * entry.getValue() / featureAttributesCountMap.get(key);
            pcda.computeIfAbsent(key, k -> new ArrayList<>()).add(pca);
        }
        // step 2: weighted average of the entropy within each feature value
        double conditionEntropy = 0.0;
        for (Map.Entry<String, Integer> entry : featureAttributesCountMap.entrySet()) {
            double p = 1.0 * entry.getValue() / totalDataSize;
            conditionEntropy += this.entropy(pcda.get(entry.getKey())) * p;
        }
        return entropyDatas - conditionEntropy;
    }

    /** Returns the most frequent target value in datas. */
    private String getResult(List<String[]> datas, int targetAttributeIndex) {
        if (datas == null || datas.isEmpty()) {
            return null;
        }
        Map<String, Integer> countMap = new HashMap<>();
        for (String[] arr : datas) {
            countMap.merge(arr[targetAttributeIndex], 1, Integer::sum);
        }
        String result = "";
        int maxCount = -1;
        for (Map.Entry<String, Integer> entry : countMap.entrySet()) {
            if (entry.getValue() > maxCount) {
                maxCount = entry.getValue();
                result = entry.getKey();
            }
        }
        return result;
    }

    /** Splits datas into subsets by the value of the given attribute. */
    private Map<String, List<String[]>> devideDatas(List<String[]> datas, int attributeIndex) {
        Map<String, List<String[]>> subdatas = new HashMap<>();
        if (datas != null) {
            for (String[] arr : datas) {
                subdatas.computeIfAbsent(arr[attributeIndex], k -> new ArrayList<>()).add(arr);
            }
        }
        return subdatas;
    }

    /** Prints the decision tree, one rule per root-to-leaf path. */
    private void printTree(Node tree, Deque<Object> stock) {
        if (tree == null) {
            return;
        }
        if (tree.result != null) {
            this.printBrance(stock, tree.result);
        } else {
            stock.push(this.attributeMap.get(tree.attr));
            for (Map.Entry<String, List<Node>> entry : tree.childs.entrySet()) {
                stock.push(entry.getKey());
                for (Node node : entry.getValue()) {
                    this.printTree(node, stock);
                }
                stock.pop();
            }
            stock.pop();
        }
    }

    /** Prints one rule; stock holds the path with the most recent element first. */
    private void printBrance(Deque<Object> stock, String target) {
        StringBuilder sb = new StringBuilder();
        int odd = 0;
        for (Object e : stock) {
            // walk from the leaf back to the root, prepending "value^" or "attribute="
            sb.insert(0, odd == 0 ? "^" : "=").insert(0, e);
            odd = 1 - odd;
        }
        sb.append("target=").append(target);
        System.out.println(sb);
    }

    /** Recursively builds the decision tree. */
    private Node buildTree(List<String[]> datas, List<Integer> attributes, double threshold, int targetIndex, Node tree) {
        if (tree == null) {
            tree = new Node(datas, null);
        }
        // split by target value; the result is empty or non-empty, never null
        Map<String, List<String[]>> subDatas = this.devideDatas(datas, targetIndex);
        if (subDatas.size() <= 1) {
            // at most one target value: this node is a leaf
            for (String key : subDatas.keySet()) {
                tree.result = key;
            }
        } else if (attributes == null || attributes.isEmpty()) {
            // no attributes left: use the most frequent target value
            tree.result = this.getResult(datas, targetIndex);
        } else {
            double maxGain = 0;
            int attr = 0;
            for (int attribute : attributes) {
                double gain = this.infoGain(datas, attribute, targetIndex);
                if (gain > maxGain) {
                    maxGain = gain;
                    attr = attribute; // index with the largest information gain
                }
            }
            if (maxGain < threshold) {
                // gain below the threshold: stop and use the most frequent target value
                tree.result = this.getResult(datas, targetIndex);
            } else {
                tree.attr = attr;
                subDatas = this.devideDatas(datas, attr);
                tree.dataset = null;
                // copy the attribute list so that removing attr does not affect sibling branches
                List<Integer> remaining = new ArrayList<>(attributes);
                remaining.remove(Integer.valueOf(attr));
                for (Map.Entry<String, List<String[]>> entry : subDatas.entrySet()) {
                    Node childTree = new Node(entry.getValue(), tree);
                    tree.childs.computeIfAbsent(entry.getKey(), k -> new ArrayList<>()).add(childTree);
                    this.buildTree(entry.getValue(), remaining, threshold, targetIndex, childTree);
                }
            }
        }
        return tree;
    }

    /** Follows the branches matching instance down to a leaf. */
    private String classify(Node tree, String[] instance) {
        if (tree == null) {
            return null;
        }
        if (tree.result != null) {
            return tree.result;
        }
        List<Node> nodes = tree.childs.get(instance[tree.attr]);
        return (nodes == null || nodes.isEmpty()) ? null : this.classify(nodes.get(0), instance);
    }

    /** Builds the decision tree. */
    public void buildTree() {
        this.tree = new Node(this.datas, null);
        this.buildTree(datas, attributes, threshold, targetIndex, tree);
    }

    /** Prints the generated rules. */
    public void printTree() {
        this.printTree(this.tree, new LinkedList<>());
    }

    /** Returns the predicted class for instance. */
    public String classify(String[] instance) {
        return this.classify(this.tree, instance);
    }

    public static void main(String[] args) {
        List<String[]> dataset = new ArrayList<>();
        dataset.add(new String[] { "1", "youth", "high", "no", "fair", "no" });
        dataset.add(new String[] { "2", "youth", "high", "no", "excellent", "no" });
        dataset.add(new String[] { "3", "middle_aged", "high", "no", "fair", "yes" });
        dataset.add(new String[] { "4", "senior", "medium", "no", "fair", "yes" });
        dataset.add(new String[] { "5", "senior", "low", "yes", "fair", "yes" });
        dataset.add(new String[] { "6", "senior", "low", "yes", "excellent", "no" });
        dataset.add(new String[] { "7", "middle_aged", "low", "yes", "excellent", "yes" });
        dataset.add(new String[] { "8", "youth", "medium", "no", "fair", "no" });
        dataset.add(new String[] { "9", "youth", "low", "yes", "fair", "yes" });
        dataset.add(new String[] { "10", "senior", "medium", "yes", "fair", "yes" });
        dataset.add(new String[] { "11", "youth", "medium", "yes", "excellent", "yes" });
        dataset.add(new String[] { "12", "middle_aged", "medium", "no", "excellent", "yes" });
        dataset.add(new String[] { "13", "middle_aged", "high", "yes", "fair", "yes" });
        dataset.add(new String[] { "14", "senior", "medium", "no", "excellent", "no" });

        List<Integer> attributes = new ArrayList<>();
        attributes.add(4);
        attributes.add(1);
        attributes.add(2);
        attributes.add(3);

        Map<Integer, String> attributeMap = new HashMap<>();
        attributeMap.put(1, "Age");
        attributeMap.put(2, "Income");
        attributeMap.put(3, "Student");
        attributeMap.put(4, "Credit_rating");

        int targetIndex = 5;

        String[] instance = new String[] { "15", "youth", "medium", "yes", "fair" };

        ID3Tree tree = new ID3Tree(dataset, attributes, attributeMap, targetIndex);
        System.out.println("Start build the tree.....");
        tree.buildTree();
        System.out.println("Completed build the tree, start print the tree.....");
        tree.printTree();
        System.out.println("start classify.....");
        String result = tree.classify(instance);
        System.out.println(result);
    }
}
```
Running the Java program prints the following (the order of the rules follows HashMap iteration order and may differ):
Start build the tree.....
Completed build the tree, start print the tree.....
Age=youth^Student=yes^target=yes
Age=youth^Student=no^target=no
Age=middle_aged^target=yes
Age=senior^Credit_rating=excellent^target=no
Age=senior^Credit_rating=fair^target=yes
start classify.....
yes
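V. Limitations of ID3
ID3 runs relatively slowly and can only work on data loaded into memory, so the datasets it can handle are smaller than those handled by some other algorithms.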