ID3决策树的Java实现

package DecisionTree;

import java.io.*;
import java.util.*;


/**
 * ID3 decision-tree learner over a column-oriented, all-categorical dataset.
 *
 * <p>A dataset is a map from column name to the list of that column's values;
 * row {@code i} consists of element {@code i} of every list. One column, named
 * by {@link #setDecisionColumn(String)}, holds the class label to predict.
 */
public class ID3 {

    /**
     * Tree node. For an internal node {@code attribute} is the name of the
     * column being tested and {@code children} maps each observed value of that
     * column to a subtree. For a leaf {@code attribute} is the predicted label
     * and {@code children} is empty.
     */
    public class DTNode {
        private String attribute;
        private HashMap<String, DTNode> children = new HashMap<>();

        public String getAttribute() {
            return attribute;
        }

        public void setAttribute(String attribute) {
            this.attribute = attribute;
        }

        public HashMap<String, DTNode> getChildren() {
            return children;
        }

        public void setChildren(HashMap<String, DTNode> children) {
            this.children = children;
        }
    }

    /** Name of the column that holds the class label. */
    private String decisionColumn;

    public String getDecisionColumn() {
        return decisionColumn;
    }

    public void setDecisionColumn(String decisionColumn) {
        this.decisionColumn = decisionColumn;
    }

    /**
     * Counts how many times each distinct value occurs in a column.
     *
     * @param dataset values of a single column
     * @return map from distinct value to its number of occurrences
     */
    public HashMap<String, Integer> getTypeCounts(ArrayList<String> dataset) {
        HashMap<String, Integer> counts = new HashMap<>();
        for (String value : dataset) {
            Integer seen = counts.get(value);
            counts.put(value, seen == null ? 1 : seen + 1);
        }
        return counts;
    }

    /**
     * Finds the row indices at which a column equals {@code key}.
     *
     * @param key     value to look for (must not be null)
     * @param dataset values of a single column
     * @return ascending list of matching row indices
     */
    public ArrayList<Integer> getIndex(String key, ArrayList<String> dataset) {
        ArrayList<Integer> indexlist = new ArrayList<>();
        for (int i = 0; i < dataset.size(); i++) {
            if (key.equals(dataset.get(i))) {
                indexlist.add(Integer.valueOf(i));
            }
        }
        return indexlist;
    }

    /**
     * Projects a column onto the given row indices.
     *
     * @param indexlist row indices to keep, in output order
     * @param dataset   values of a single column
     * @return the selected values
     */
    public ArrayList<String> getSubset(ArrayList<Integer> indexlist, ArrayList<String> dataset) {
        ArrayList<String> subset = new ArrayList<>(indexlist.size());
        for (Integer i : indexlist) {
            subset.add(dataset.get(i.intValue()));
        }
        return subset;
    }

    /**
     * Computes the Shannon entropy (base 2) of the value distribution of a column.
     *
     * @param dataset values of a single column
     * @return entropy in bits; 0 for an empty or single-valued column
     */
    public double getEntropy(ArrayList<String> dataset) {
        double entropy = 0;
        int sum = dataset.size();
        for (Integer count : getTypeCounts(dataset).values()) {
            double prob = (double) count.intValue() / sum;
            entropy += -1 * prob * Math.log10(prob) / Math.log10(2);
        }
        return entropy;
    }

    /**
     * Computes the entropy of the decision column conditioned on {@code IndexCol}:
     * the weighted average, over each value of {@code IndexCol}, of the label
     * entropy among the rows having that value.
     *
     * @param dataset  column-oriented dataset containing both {@code IndexCol}
     *                 and the decision column
     * @param IndexCol name of the conditioning column
     * @return conditional entropy in bits
     */
    public double getConditionEntropy(HashMap<String, ArrayList<String>> dataset, String IndexCol) {
        ArrayList<String> column = dataset.get(IndexCol);
        ArrayList<String> labels = dataset.get(this.decisionColumn);
        double entropy = 0;
        int sum = column.size();
        HashMap<String, Integer> counts = getTypeCounts(column);
        for (String value : counts.keySet()) {
            double prob = (double) counts.get(value) / sum;
            entropy += prob * getEntropy(getSubset(getIndex(value, column), labels));
        }
        return entropy;
    }

    /**
     * Recursively builds an ID3 decision tree.
     *
     * <p>At each step the remaining attribute with the highest information gain
     * is chosen, the rows are partitioned by its values, and the attribute is
     * removed from the partitions. Recursion stops when all labels agree, or
     * when no attribute is left, in which case the most frequent label is used.
     *
     * @param dataset column-oriented dataset that contains the decision column
     * @return root of the (sub)tree
     * @throws IllegalArgumentException if {@code dataset} is null or has no
     *         values for the decision column
     */
    public DTNode buildDT(HashMap<String, ArrayList<String>> dataset) {
        if (dataset == null) {
            throw new IllegalArgumentException("dataset must not be null");
        }
        ArrayList<String> labels = dataset.get(this.decisionColumn);
        if (labels == null || labels.isEmpty()) {
            throw new IllegalArgumentException(
                    "dataset has no values for decision column '" + this.decisionColumn + "'");
        }

        DTNode node = new DTNode();
        double labelEntropy = getEntropy(labels);

        // Pure partition: every row carries the same label.
        if (labelEntropy == 0) {
            node.setAttribute(labels.get(0));
            return node;
        }

        // Pick the attribute with the largest gain. Starting from -infinity means
        // an attribute is chosen even when every gain is zero (e.g. XOR-like data),
        // so the tree can still separate the rows at a deeper level.
        String bestColumn = null;
        double bestGain = Double.NEGATIVE_INFINITY;
        for (String column : dataset.keySet()) {
            if (column.equals(this.decisionColumn)) {
                continue;
            }
            double gain = labelEntropy - getConditionEntropy(dataset, column);
            if (gain > bestGain) {
                bestGain = gain;
                bestColumn = column;
            }
        }

        // Attributes exhausted but labels still mixed: predict the majority label.
        if (bestColumn == null) {
            node.setAttribute(majorityLabel(labels));
            return node;
        }

        node.setAttribute(bestColumn);
        ArrayList<String> splitValues = dataset.get(bestColumn);

        // One child per distinct value of the chosen attribute; each child sees
        // only the matching rows and no longer contains the chosen attribute.
        for (String value : getTypeCounts(splitValues).keySet()) {
            ArrayList<Integer> rows = getIndex(value, splitValues);
            HashMap<String, ArrayList<String>> subset = new HashMap<>();
            for (String column : dataset.keySet()) {
                if (!column.equals(bestColumn)) {
                    subset.put(column, getSubset(rows, dataset.get(column)));
                }
            }
            node.getChildren().put(value, buildDT(subset));
        }

        return node;
    }

    /** Returns the most frequent label; ties go to the label that appears first. */
    private String majorityLabel(ArrayList<String> labels) {
        HashMap<String, Integer> counts = getTypeCounts(labels);
        String best = null;
        int bestCount = 0;
        for (String label : labels) {
            int count = counts.get(label);
            if (count > bestCount) {
                bestCount = count;
                best = label;
            }
        }
        return best;
    }

    /**
     * Prints the tree to standard output in pre-order: the node's attribute on
     * its own line, then for each child the branch value followed by a space and
     * the child's subtree.
     *
     * @param root root of the tree to print; null prints nothing
     */
    public void printDT(DTNode root) {
        if (root == null) {
            return;
        }
        System.out.println(root.getAttribute());
        HashMap<String, DTNode> children = root.getChildren();
        if (children == null) {
            return;
        }
        for (String value : children.keySet()) {
            System.out.print(value + " ");
            printDT(children.get(value));
        }
    }

    /**
     * Reads a comma-separated, UTF-8 text file whose first non-blank line is the
     * header, and returns it in column-oriented form.
     *
     * <p>Blank lines are ignored. Rows whose field count differs from the header
     * are reported on standard error and skipped. If the file is missing or
     * cannot be read, a message is written to standard error and an empty map is
     * returned.
     *
     * @param path path of the file to read
     * @return map from column name to that column's values, in row order
     */
    public HashMap<String, ArrayList<String>> read(String path) {
        HashMap<String, ArrayList<String>> dataset = new HashMap<>();
        if (path == null) {
            System.err.println("Input path is null");
            return dataset;
        }
        File file = new File(path);
        if (!file.isFile()) {
            System.err.println("Input file not found: " + path);
            return dataset;
        }

        // try-with-resources closes the reader (and the underlying stream) on every path.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(file), "UTF-8"))) {
            String[] header = null;
            ArrayList<ArrayList<String>> columns = new ArrayList<>();
            int lineNo = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                lineNo++;
                if (line.trim().isEmpty()) {
                    continue;   // tolerate blank lines, notably a trailing one
                }
                String[] fields = line.split(",");
                if (header == null) {
                    header = fields;
                    for (int i = 0; i < header.length; i++) {
                        columns.add(new ArrayList<String>());
                    }
                    continue;
                }
                if (fields.length != header.length) {
                    System.err.println("Skipping line " + lineNo + " of " + path + ": expected "
                            + header.length + " fields but found " + fields.length);
                    continue;
                }
                for (int i = 0; i < fields.length; i++) {
                    columns.get(i).add(fields[i]);
                }
            }
            if (header != null) {
                for (int i = 0; i < header.length; i++) {
                    dataset.put(header[i], columns.get(i));
                }
            }
        } catch (IOException e) {
            System.err.println("Failed to read " + path + ": " + e);
        }

        return dataset;
    }

    /**
     * Builds and prints a tree.
     *
     * @param args optional: {@code args[0]} is the input file (defaults to the
     *             sample file on the original author's desktop), {@code args[1]}
     *             is the decision column (defaults to {@code "play"})
     */
    public static void main(String[] args) {
        String path = args.length > 0
                ? args[0]
                : "C:" + File.separator + "Users" + File.separator + "mhua005" + File.separator
                        + "Desktop" + File.separator + "sample.txt";
        String decision = args.length > 1 ? args[1] : "play";

        ID3 tree = new ID3();
        HashMap<String, ArrayList<String>> ds = tree.read(path);
        tree.setDecisionColumn(decision);
        if (!ds.containsKey(decision)) {
            System.err.println("No '" + decision + "' column loaded from " + path);
            return;
        }
        DTNode root = tree.buildDT(ds);
        tree.printDT(root);
    }
}

 

源文件内容:

outlook,temperature,humidity,windy,play
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no

 

posted @ 2016-03-10 15:42  finalboss1987  阅读(463)  评论(0)  编辑  收藏  举报