ID3算法(Java实现)

数据存储文件:buycomputer.properties

#数据个数
datanum=14
#属性及属性值
nodeAndAttribute=年龄:青/中/老,收入:高/中/低,学生:是/否,信誉:良/优,归类:买/不买
#数据
D1=青,高,否,良,不买
D2=青,高,否,优,不买
D3=中,高,否,良,买
D4=老,中,否,良,买
D5=老,低,是,良,买
D6=老,低,是,优,不买
D7=中,低,是,优,买
D8=青,中,否,良,不买
D9=青,低,是,良,买
D10=老,中,是,良,买
D11=青,中,是,优,买
D12=中,中,否,优,买
D13=中,高,是,良,买
D14=老,中,否,优,不买
D15=老,中,否,优,买
View Code

实体类:TreeNode.java

package com.id3.node;

import java.util.HashMap;
import java.util.Map;

public class TreeNode {

    private String nodeName;
    private Map<String,Attributes> attributes;
    private double gain;
    
    public double getGain() {
        return gain;
    }
    public void setGain(double gain) {
        this.gain = gain;
    }
    public String getNodeName() {
        
        return nodeName;
    }
    public void setNodeName(String nodeName) {
        this.nodeName = nodeName;
    }
    public Map<String, Attributes> getAttributes() {
        return attributes;
    }
    public void setAttributes(Map<String, Attributes> attributes) {
        
        this.attributes = attributes;
    }
    
    @Override
    public String toString() {
        return "TreeNode [nodeName=" + nodeName + ", attributes=" + attributes
                + ", gain=" + gain + "]";
    }
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result
                + ((attributes == null) ? 0 : attributes.hashCode());
        long temp;
        temp = Double.doubleToLongBits(gain);
        result = prime * result + (int) (temp ^ (temp >>> 32));
        result = prime * result
                + ((nodeName == null) ? 0 : nodeName.hashCode());
        return result;
    }
    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        TreeNode other = (TreeNode) obj;
        if (attributes == null) {
            if (other.attributes != null)
                return false;
        } else if (!attributes.equals(other.attributes))
            return false;
        if (Double.doubleToLongBits(gain) != Double
                .doubleToLongBits(other.gain))
            return false;
        if (nodeName == null) {
            if (other.nodeName != null)
                return false;
        } else if (!nodeName.equals(other.nodeName))
            return false;
        return true;
    }
    
    
    
    
}


class Attributes{
    
    private String attrName;
    private TreeNode nextNode;
    private String leafName;
    private int attrNum;
    private double h;
    Map<String, Integer> resultNum = new HashMap<String, Integer>();
    
    public String getLeafName() {
        return leafName;
    }
    public void setLeafName(String leafName) {
        this.leafName = leafName;
    }
    public Map<String, Integer> getResultNum() {
        return resultNum;
    }
    public void setResultNum(Map<String, Integer> resultNum) {
        this.resultNum = resultNum;
    }
    public double getH() {
        return h;
    }
    public void setH(double h) {
        this.h = h;
    }
    public String getAttrName() {
        return attrName;
    }
    public void setAttrName(String attrName) {
        this.attrName = attrName;
    }
    public TreeNode getNextNode() {
        return nextNode;
    }
    public void setNextNode(TreeNode nextNode) {
        this.nextNode = nextNode;
    }
    public int getAttrNum() {
        return attrNum;
    }
    public void setAttrNum(int attrNum) {
        this.attrNum = attrNum;
    }
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result
                + ((attrName == null) ? 0 : attrName.hashCode());
        result = prime * result + attrNum;
        long temp;
        temp = Double.doubleToLongBits(h);
        result = prime * result + (int) (temp ^ (temp >>> 32));
        result = prime * result
                + ((leafName == null) ? 0 : leafName.hashCode());
        result = prime * result
                + ((nextNode == null) ? 0 : nextNode.hashCode());
        result = prime * result
                + ((resultNum == null) ? 0 : resultNum.hashCode());
        return result;
    }
    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        Attributes other = (Attributes) obj;
        if (attrName == null) {
            if (other.attrName != null)
                return false;
        } else if (!attrName.equals(other.attrName))
            return false;
        if (attrNum != other.attrNum)
            return false;
        if (Double.doubleToLongBits(h) != Double.doubleToLongBits(other.h))
            return false;
        if (leafName == null) {
            if (other.leafName != null)
                return false;
        } else if (!leafName.equals(other.leafName))
            return false;
        if (nextNode == null) {
            if (other.nextNode != null)
                return false;
        } else if (!nextNode.equals(other.nextNode))
            return false;
        if (resultNum == null) {
            if (other.resultNum != null)
                return false;
        } else if (!resultNum.equals(other.resultNum))
            return false;
        return true;
    }
    @Override
    public String toString() {
        return "Attributes [attrName=" + attrName + ", nextNode=" + nextNode
                + ", leafName=" + leafName + ", attrNum=" + attrNum + ", h="
                + h + ", resultNum=" + resultNum + "]";
    }
    
}
View Code

ID3算法:ID3Alogo.java

package com.id3.node;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
 * ID3算法
 * @author JoMint
 *
 */
public class ID3Alogo {

    //存每个节点及其属性等相关变量
    private List<TreeNode> treeList;
    //存数据集
    private List<Map<String, String>> dataList;
    //遍历决策树时的开始节点
    private Attributes startNode;
    //决策结果变量的值
    private List<String> resultList;
    //结果属性节点
    private TreeNode resultNode;
    //决策树
    private String str;

    //构建决策树的开始调用方法
    public void ID3(String id3Name,String readPath,String printPath){
        
        //初始化成员变量
        initElement(id3Name);
        //读数据
        readData(readPath);
        //构建决策树
        cusTree(dataList, treeList, startNode);
        //System.out.println(startNode.getNextNode().getAttributes().get("Overcast").getLeafName());
        //遍历决策树,并把结果存入str中
        
        printTree(startNode,"");
        //打印决策树
        System.out.println(str);
        //输出决策树到文件
        printTreetoTxt(printPath);
        
    }
    
    /**
     * 初始化成员变量
     */
    private void initElement(String id3Name) {
        
        //存每个节点及其属性等相关变量
        treeList = new ArrayList<TreeNode>();
        //存数据集
        dataList = new ArrayList<Map<String,String>>();
        //遍历决策树时的开始节点
        startNode = new Attributes();
        //决策结果变量的值
        resultList = new ArrayList<String>();
        //结果属性节点
        TreeNode resultNode = null;
        //决策树
        str = id3Name+"决策树:\r\n";
        
    }


    /**
     * 读数据
     */
    private void readData(String path) {
        
        Map<String, String> dataMap;
        Map<String,Attributes> attrMap;
        TreeNode treeNode;
        int num;
        
        //创建读取properties文件的对象
        Properties pro = new Properties();
        
        
        try {
            //为了读取中文字符,将读取文件的类型改为字符流读取
            InputStream inputStream = new FileInputStream(path);
            BufferedReader bf = new BufferedReader(new InputStreamReader(inputStream));
            //加载数据文件
            pro.load(bf);
            //读取数据总个数
            num = Integer.parseInt(pro.getProperty("datanum"));
            //读取属性及属性值
            String attribute = pro.getProperty("nodeAndAttribute");
            //将每个属性分开,用数组存,遍历每个属性,再把每个属性的属性值分开,存到treeList中
            String[] attArray = attribute.split(",");
            for (int i = 0; i < attArray.length; i++) {
                
                treeNode = new TreeNode();
                String[] temp = attArray[i].split(":");
                String nodeName = temp[0];
                String[] attr = temp[1].split("/");
                treeNode.setNodeName(nodeName);
                attrMap = new HashMap<String, Attributes>();
                Attributes attributes;
                for (int j = 0; j < attr.length; j++) {
                    //Map<String, Integer> map = new HashMap<String, Integer>();
                    attributes = new Attributes();
                    //map.put(attr[j], 0);
                    attributes.setAttrName(attr[j]);
                    attrMap.put(attr[j], attributes);
                    
                    //存入结果变量的值,为最后的判断做铺垫
                    if(i == attArray.length-1){
                        
                        resultList.add(attr[j]);
                        
                    }
                    
                }
                treeNode.setAttributes(attrMap);
                treeList.add(treeNode);
            }
            
            //遍历数据集,将数据按行存入dataList中
            for (int i = 1; i <= num; i++) {
                
                dataMap = new HashMap<String, String>();
                String key = "D"+i;
                String[] colline = pro.getProperty(key).split(",");
                //System.out.println(key+"=="+colline.length);
                for (int j = 0; j < treeList.size(); j++) {
                    //System.out.println(treeList.size());
                    dataMap.put(treeList.get(j).getNodeName(), colline[j]);
                }
                dataList.add(dataMap);
            }
            
            //得到结果属性的名字
            resultNode = treeList.get(treeList.size()-1);
            
            
//            System.out.println("************************resultNode==" + resultNode + "***********************");
        } catch (FileNotFoundException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }

    }

    
    /**
     * 数据处理
     * @param cdataList
     * @param ctreeList
     */
    private List<List> dealData(List<Map<String, String>> dataList, List<TreeNode> treeList){
        
        
        List<List> returnList= new ArrayList<List>();
        int num = dataList.size();
        
        /*
         * 统计数据集中每个属性的属性值个数
         */
        Map<String, Attributes> attrMap = new HashMap<String, Attributes>();
        Map<String, Integer> resultMap;
        for (int i = 0; i < treeList.size(); i++) {
            for (int j = 0; j < dataList.size(); j++) {
                //获得当前数据集中当前列当前行的属性值
                String key = dataList.get(j).get(treeList.get(i).getNodeName()); 
                attrMap = treeList.get(i).getAttributes();
                //System.out.println(attrMap.get(key)+"=="+key);
                //计算样本中对应的属性变量的个数
                attrMap.get(key).setAttrNum(attrMap.get(key).getAttrNum()+1); 
                
                //System.out.println("->"+attrMap.get(key));
                
                //获得结果变量值
                String result = dataList.get(j).get(treeList.get(treeList.size()-1).getNodeName()); 
                resultMap = attrMap.get(key).getResultNum();
                //如果包含这个结果变量,则数量上加1; 如果不包含,赋初值为1
                if (resultMap.containsKey(result)) {      
                    resultMap.put(result, resultMap.get(result)+1);
                }else{
                    resultMap.put(result, 1);
                }
            }
        }
        /*
         * 计算熵
         */
        DecimalFormat df = new DecimalFormat("#.###");   
        for (int i = 0; i < treeList.size(); i++) {
            //遍历 Attributes
            //计算属性熵: gain
            double gain = 0.0;
            for (Map.Entry<String, Attributes> element : treeList.get(i).getAttributes().entrySet()) {
                Attributes attr = treeList.get(i).getAttributes().get(element.getKey());
                Map<String, Integer> result = attr.getResultNum();
                //遍历每个 Attributes 的 resultNum
                //计算属性值的熵 :h
                double h = 0.0;
                for (Map.Entry<String, Integer> element2 : result.entrySet()) {
                    double resultNum = (double)result.get(element2.getKey());
                    double attrNum = (double)attr.getAttrNum();
                    resultNum = resultNum/attrNum;
                    h -= (resultNum*(Math.log(resultNum)/Math.log((double)2)));
                    h = Double.parseDouble(df.format(h));
                    attr.setH(h);
                    //System.out.println("resultNum=========="+resultNum);
                }
                //System.out.println(" attr==>"+attr);
                gain += ((double)attr.getAttrNum()/num)*attr.getH();
                gain = Double.parseDouble(df.format(gain));
                
                //System.out.println("gain=="+gain);
            }
            
            treeList.get(i).setGain(gain);
            //System.out.println(" gain-->"+treeList.get(i));
            
        }
        
        //将处理好的dataList和treeList放在returnList中返回
        returnList.add(dataList);
        returnList.add(treeList);
        
        return returnList;
        
//        System.out.println("***************************************************+++++++↓");
//        for (int i = 0; i < treeList.size(); i++) {
//            System.out.println(treeList.get(i));
//        }
//        System.out.println();
//        for (int i = 0; i < dataList.size(); i++) {
//            System.out.println(dataList.get(i));
//        }
//        
//        System.out.println("================================================="+num+"条数据=="+treeList.size()+"个属性");
//        System.out.println("***************************************************+++++++↑");
        
    }
    
    /**
     * 构建决策树
     * @param dataList
     * @param treeList
     */
    @SuppressWarnings("unchecked")
    private void cusTree(List<Map<String, String>> dataList, List<TreeNode> treeList, Attributes cAttr){
        
        List<List> curryList= new ArrayList<List>();
        
        //处理数据
        
        curryList = dealData(dataList, treeList);
        
        
        //从 curryList 中得到 dataList 和 treeList
        dataList = (List<Map<String, String>>)curryList.get(0);
        treeList = (List<TreeNode>)curryList.get(1);
        
        
        //判断当前处理的数据集中的决策结果,若决策结果相同的个数等于总的当前处理的数据集的条数,则遍历结束
        //将当前的决策结果放入当前判断的属性值的后边
        //返回到调用这个函数的父函数
        for (TreeNode treeNode : treeList) {
            if (treeNode.getNodeName().equals(resultNode.getNodeName())) {
                for (String attr : resultList) {
                    if (treeNode.getAttributes().get(attr).getAttrNum() == dataList.size()) {
                        cAttr.setLeafName(attr);
                        return;
                    }
                }
            }
        }
        
//        System.out.println("=_=_=_=_=_=_=datalist==="+dataList);
//        System.out.println("=_=_=_=_=_=_=treelist==="+treeList);
        
        //寻找最优解
        
        //得到根节点
        TreeNode rootNode = treeList.get(0);
        
        for (TreeNode treeNode : treeList) {
            
            if(!treeNode.getNodeName().equals(treeList.get(treeList.size()-1).getNodeName())){
                if(treeNode.getGain() < rootNode.getGain()){
                    rootNode = treeNode;
                }
            }
            
        }
    //    System.out.println("*********↓↓↓↓↓↓↓↓***********当前根节点为:"+rootNode.getNodeName()+"***********↓↓↓↓↓↓↓↓*********");
        
        cAttr.setNextNode(rootNode);
        
        //对当前根节点的属性进行遍历,寻找下一个节点
        
        //节点名
        String nodeName = rootNode.getNodeName();
        //属性名
        String attrName = "";
        //属性节点
        Attributes attr = new Attributes();
        //当前节点的属性值集合
        Map<String, Attributes> attrMap = rootNode.getAttributes();
        
        
        //遍历节点的每个属性值
        for (Map.Entry<String, Attributes> entry : attrMap.entrySet()) {
            
            attr = attrMap.get(entry.getKey());
            attrName = attr.getAttrName();
            
            
//            System.out.println("*****************attrName========"+attrName+"******************");
            
            //得到新的data集合对象
            
            List<Map<String, String>> newDataList = new ArrayList<Map<String,String>>();
            Map<String, String> newMap = new HashMap<String, String>();
            //String attrName = rootNode.getAttributes().get("Sunny").getAttrName();
            newMap.clear();
            
            //删除dataList中已处理过的节点数据
            //遍历dataList
            for (Map<String, String> map : dataList) {
                
                if(map.containsKey(nodeName)){
                    
                    if(map.get(nodeName).equals(attrName)){
                        newMap = new HashMap<String, String>();
                        for (Map.Entry<String, String> m : map.entrySet()) {
                            
                            //如果该节点不是已处理过的节点
                            if(!m.getKey().equals(nodeName)){
                                //得到新的节点
                                newMap.put(m.getKey(), map.get(m.getKey()));
                            }
                            
                        }
                        
                        //将新的节点存入newDataList中
                        newDataList.add(newMap);
                    }
                    
                }
                
            } 
//            System.out.println("↓↓↓↓↓↓*******************新的data集合:*******************↓↓↓↓↓↓");
//            for (Map<String, String> map : newDataList) {
//                System.out.println(map);
//            }
            
            //获得新的tree集合对象,而且值为初值
            
            List<TreeNode> newTreeList = new ArrayList<TreeNode>();
            
            //将treeList中的数据清空
            clearTree(treeList);
            
            //删除treeList中已处理过的节点
            for (TreeNode treeNode : treeList) {
                if(!treeNode.getNodeName().equals(nodeName)){
                    newTreeList.add(treeNode);
                }
            }
//            System.out.println("↓↓↓↓↓↓*******************新的tree集合:*******************↓↓↓↓↓↓");
//            for (TreeNode treeNode : newTreeList) {
//                System.out.println(treeNode);
//            }
            
            //递归调用当前函数,继续找节点
            cusTree(newDataList, newTreeList,attr);
        }
    }
    
    /**
     * 输出决策树
     * @param attr
     */
    private void printTree(Attributes attr, String ceil) {
        
        String nodeName = attr.getNextNode().getNodeName();
        Map<String, Attributes> attrMap = attr.getNextNode().getAttributes();
        
        str += ceil+"----"+nodeName+"\r\n";
        for (Map.Entry<String, Attributes> nextAttr : attrMap.entrySet()) {
            
            //如果当前属性值没有下一个节点,则将当前属性值的名称及决策结果输出
            if(attrMap.get(nextAttr.getKey()).getNextNode() == null){
                
                str += ceil+"-------"+attrMap.get(nextAttr.getKey()).getAttrName()+"\r\n";
                str += ceil+"----------"+attrMap.get(nextAttr.getKey()).getLeafName()+"\r\n";
                
            }else{
                
                str += ceil+"-------"+attrMap.get(nextAttr.getKey()).getAttrName()+"\r\n";
                printTree(attrMap.get(nextAttr.getKey()),"------");
            }
        }

    }
    
    /**
     * 打印决策树到txt文本
     * @param path
     */
    private void printTreetoTxt(String path){
        
        if(path == null || path.equals("")) return;
        File file = new File(path);
        File folder = file.getParentFile();
        FileWriter fw;
        try {
            
            if(!folder.exists()){
                folder.mkdirs();
                file.createNewFile();
            }
            
            fw = new FileWriter(file);
            fw.write(str);
            
            fw.flush();
            fw.close();
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
    }
    
    
    /**
     * 还原初始数据
     * @param treeList
     */
    private void clearTree(List<TreeNode> treeList){
        
        for (TreeNode treeNode : treeList) {
            Map<String, Attributes> map = treeNode.getAttributes();
            
            for (Map.Entry<String, Attributes> entry : map.entrySet()) {
                Attributes attr = map.get(entry.getKey());
                attr.setAttrNum(0);
                attr.setH(0);
                Map<String, Integer> map2 = attr.getResultNum();
                map2.clear();
            }
            treeNode.setGain(0);
        }
    }
    
}
View Code

主函数:ID3Main.java

package com.id3.node;


public class ID3Main {

    public static void main(String[] args) {
        
        ID3Alogo id3Alogo = new ID3Alogo();
        id3Alogo.ID3("决策树名","数据文件地址", "输出文件地址");
        
    }
    
}
View Code

 

posted @ 2015-04-14 21:27  新绿薄荷  阅读(1648)  评论(3编辑  收藏  举报