ID3决策树的Java实现
package DecisionTree; import java.io.*; import java.util.*; public class ID3 { //节点类 public class DTNode { private String attribute; private HashMap<String, DTNode> children = new HashMap<String, DTNode>(); public String getAttribute() { return attribute; } public void setAttribute(String attribute) { this.attribute = attribute; } public HashMap<String, DTNode> getChildren() { return children; } public void setChildren(HashMap<String, DTNode> children) { this.children = children; } } private String decisionColumn; //决定字段 public String getDecisionColumn() { return decisionColumn; } public void setDecisionColumn(String decisionColumn) { this.decisionColumn = decisionColumn; } //统计每个属性在集合中出现的次数 public HashMap<String, Integer> getTypeCounts(ArrayList<String> dataset) { HashMap<String, Integer> map = new HashMap<String, Integer>(); for (int i = 0; i < dataset.size(); i++) { String key = dataset.get(i); if(!map.containsKey(key)) map.put(key, 1); else map.put(key, map.get(key)+1); } return map; } //获取key的indexlist public ArrayList<Integer> getIndex(String key, ArrayList<String> dataset){ ArrayList<Integer> indexlist = new ArrayList<Integer>(); for(int i = 0; i < dataset.size(); i++){ if(key.equals(dataset.get(i))) indexlist.add(Integer.valueOf(i)); } return indexlist; } //根据index获取数据集 public ArrayList<String> getSubset(ArrayList<Integer> indexlist, ArrayList<String> dataset) { ArrayList<String> subset = new ArrayList<String>(); for(Integer i : indexlist){ subset.add(dataset.get(i.intValue())); } return subset; } //计算信息熵 public double getEntropy(ArrayList<String> dataset) { double entropy = 0; double prob = 0; int sum = dataset.size(); HashMap<String, Integer> map = getTypeCounts(dataset); Iterator<String> iter = map.keySet().iterator(); while(iter.hasNext()){ String key = iter.next(); prob = (double)map.get(key).intValue()/sum; entropy += -1*prob*Math.log10(prob)/Math.log10(2); } return entropy; } //计算已知条件下的信息熵 public double getConditionEntropy(HashMap<String, ArrayList<String>> dataset, String IndexCol) { double entropy = 0; double prob = 0; int sum = dataset.get(IndexCol).size(); HashMap<String, Integer> map = getTypeCounts(dataset.get(IndexCol)); Iterator<String> iter = map.keySet().iterator(); while(iter.hasNext()){ String key = iter.next(); prob = (double)map.get(key)/sum; entropy+=prob*getEntropy(getSubset(getIndex(key,dataset.get(IndexCol)),dataset.get(this.decisionColumn))); } return entropy; } //建立决策树 public DTNode buildDT(HashMap<String, ArrayList<String>>dataset) { DTNode node = new DTNode(); double info_entropy = getEntropy(dataset.get(this.decisionColumn)); //递归结束条件 if(info_entropy == 0){ node.setAttribute((dataset.get(this.decisionColumn).get(0))); return node; } //求出拥有最小熵数据集的column,即最大entropy gain double max_gain = 0; //设置默认值 double gain = 0; String max_column=""; Iterator<String> entropy_iter = dataset.keySet().iterator(); while(entropy_iter.hasNext()){ String key = entropy_iter.next(); if(key.equals(this.decisionColumn)) continue; gain = getEntropy(dataset.get(decisionColumn)) - getConditionEntropy(dataset,key); //计算信息增益 if(gain > max_gain){ max_gain = gain; max_column = key; } } node.setAttribute(max_column); ArrayList<String> ds = dataset.get(max_column); //最小熵数据集 //生成新数据集 Iterator<String> iter = getTypeCounts(ds).keySet().iterator(); while(iter.hasNext()){ String key = iter.next(); HashMap<String, ArrayList<String>> subset = new HashMap<String, ArrayList<String>>(); DTNode childNode; ArrayList<Integer> indexlist = getIndex(key,ds); Iterator<String> sub_iter = dataset.keySet().iterator(); while(sub_iter.hasNext()){ String sub_key = sub_iter.next(); if(!sub_key.equals(max_column)) subset.put(sub_key, getSubset(indexlist,dataset.get(sub_key))); } childNode = buildDT(subset); node.getChildren().put(key, childNode); } return node; } //输出树 public void printDT(DTNode root){ if(root == null) return; System.out.println(root.attribute); if(root.getChildren() == null) return; Iterator<String> iter = root.getChildren().keySet().iterator(); while(iter.hasNext()){ String key = iter.next(); System.out.print(key+" "); printDT(root.getChildren().get(key)); } } //读取源文件 public HashMap<String,ArrayList<String>> read(String path){ HashMap<String,ArrayList<String>> dataset = new HashMap<String,ArrayList<String>>(); try{ File file = new File(path); if(file.isFile() && file.exists()){ //判断文件是否存在 InputStreamReader input = new InputStreamReader(new FileInputStream(file),"UTF-8"); BufferedReader read = new BufferedReader(input); String line = null; ArrayList<ArrayList<String>> ds = new ArrayList<ArrayList<String>>(); while((line = read.readLine()) != null){ String[] data = line.split(","); ArrayList<String> temp = new ArrayList<String>(); for(int i = 0; i < data.length; i++) temp.add(data[i]); ds.add(temp); } for(int i = 0; i < ds.get(0).size(); i++){ ArrayList<String> newds = new ArrayList<String>(); for(int j = 0; j < ds.size(); j++){ newds.add(ds.get(j).get(i)); } String key = newds.get(0); newds.remove(0); dataset.put(key,newds); } input.close(); } }catch(Exception e){ e.printStackTrace(); } return dataset; } public static void main(String[] args) { ID3 tree = new ID3(); HashMap<String,ArrayList<String>> ds = tree.read("C:"+File.separator+"Users"+File.separator+"mhua005"+File.separator+ "Desktop"+File.separator+"sample.txt"); tree.setDecisionColumn("play"); ArrayList<String> attr = new ArrayList<String>(); attr.add("outlook"); attr.add("temperature"); attr.add("humidity"); attr.add("windy"); attr.add("play"); DTNode root = tree.buildDT(ds); tree.printDT(root); } }
源文件内容:
outlook,temperature,humidity,windy,play
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no