ID3算法(Java实现)
数据存储文件:buycomputer.properties
#数据个数
datanum=14
#属性及属性值
nodeAndAttribute=年龄:青/中/老,收入:高/中/低,学生:是/否,信誉:良/优,归类:买/不买
#数据
D1=青,高,否,良,不买
D2=青,高,否,优,不买
D3=中,高,否,良,买
D4=老,中,否,良,买
D5=老,低,是,良,买
D6=老,低,是,优,不买
D7=中,低,是,优,买
D8=青,中,否,良,不买
D9=青,低,是,良,买
D10=老,中,是,良,买
D11=青,中,是,优,买
D12=中,中,否,优,买
D13=中,高,是,良,买
D14=老,中,否,优,不买
# NOTE: datanum=14, so the loader only reads D1..D14 and never sees D15.
# D15 repeats the attribute values of D14 with the opposite class.
D15=老,中,否,优,买
实体类:TreeNode.java
package com.id3.node; import java.util.HashMap; import java.util.Map; public class TreeNode { private String nodeName; private Map<String,Attributes> attributes; private double gain; public double getGain() { return gain; } public void setGain(double gain) { this.gain = gain; } public String getNodeName() { return nodeName; } public void setNodeName(String nodeName) { this.nodeName = nodeName; } public Map<String, Attributes> getAttributes() { return attributes; } public void setAttributes(Map<String, Attributes> attributes) { this.attributes = attributes; } @Override public String toString() { return "TreeNode [nodeName=" + nodeName + ", attributes=" + attributes + ", gain=" + gain + "]"; } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + ((attributes == null) ? 0 : attributes.hashCode()); long temp; temp = Double.doubleToLongBits(gain); result = prime * result + (int) (temp ^ (temp >>> 32)); result = prime * result + ((nodeName == null) ? 
0 : nodeName.hashCode()); return result; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; TreeNode other = (TreeNode) obj; if (attributes == null) { if (other.attributes != null) return false; } else if (!attributes.equals(other.attributes)) return false; if (Double.doubleToLongBits(gain) != Double .doubleToLongBits(other.gain)) return false; if (nodeName == null) { if (other.nodeName != null) return false; } else if (!nodeName.equals(other.nodeName)) return false; return true; } } class Attributes{ private String attrName; private TreeNode nextNode; private String leafName; private int attrNum; private double h; Map<String, Integer> resultNum = new HashMap<String, Integer>(); public String getLeafName() { return leafName; } public void setLeafName(String leafName) { this.leafName = leafName; } public Map<String, Integer> getResultNum() { return resultNum; } public void setResultNum(Map<String, Integer> resultNum) { this.resultNum = resultNum; } public double getH() { return h; } public void setH(double h) { this.h = h; } public String getAttrName() { return attrName; } public void setAttrName(String attrName) { this.attrName = attrName; } public TreeNode getNextNode() { return nextNode; } public void setNextNode(TreeNode nextNode) { this.nextNode = nextNode; } public int getAttrNum() { return attrNum; } public void setAttrNum(int attrNum) { this.attrNum = attrNum; } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + ((attrName == null) ? 0 : attrName.hashCode()); result = prime * result + attrNum; long temp; temp = Double.doubleToLongBits(h); result = prime * result + (int) (temp ^ (temp >>> 32)); result = prime * result + ((leafName == null) ? 0 : leafName.hashCode()); result = prime * result + ((nextNode == null) ? 0 : nextNode.hashCode()); result = prime * result + ((resultNum == null) ? 
0 : resultNum.hashCode()); return result; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; Attributes other = (Attributes) obj; if (attrName == null) { if (other.attrName != null) return false; } else if (!attrName.equals(other.attrName)) return false; if (attrNum != other.attrNum) return false; if (Double.doubleToLongBits(h) != Double.doubleToLongBits(other.h)) return false; if (leafName == null) { if (other.leafName != null) return false; } else if (!leafName.equals(other.leafName)) return false; if (nextNode == null) { if (other.nextNode != null) return false; } else if (!nextNode.equals(other.nextNode)) return false; if (resultNum == null) { if (other.resultNum != null) return false; } else if (!resultNum.equals(other.resultNum)) return false; return true; } @Override public String toString() { return "Attributes [attrName=" + attrName + ", nextNode=" + nextNode + ", leafName=" + leafName + ", attrNum=" + attrNum + ", h=" + h + ", resultNum=" + resultNum + "]"; } }
ID3算法:ID3Alogo.java
package com.id3.node; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.FileWriter; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.text.DecimalFormat; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Properties; /** * ID3算法 * @author JoMint * */ public class ID3Alogo { //存每个节点及其属性等相关变量 private List<TreeNode> treeList; //存数据集 private List<Map<String, String>> dataList; //遍历决策树时的开始节点 private Attributes startNode; //决策结果变量的值 private List<String> resultList; //结果属性节点 private TreeNode resultNode; //决策树 private String str; //构建决策树的开始调用方法 public void ID3(String id3Name,String readPath,String printPath){ //初始化成员变量 initElement(id3Name); //读数据 readData(readPath); //构建决策树 cusTree(dataList, treeList, startNode); //System.out.println(startNode.getNextNode().getAttributes().get("Overcast").getLeafName()); //遍历决策树,并把结果存入str中 printTree(startNode,""); //打印决策树 System.out.println(str); //输出决策树到文件 printTreetoTxt(printPath); } /** * 初始化成员变量 */ private void initElement(String id3Name) { //存每个节点及其属性等相关变量 treeList = new ArrayList<TreeNode>(); //存数据集 dataList = new ArrayList<Map<String,String>>(); //遍历决策树时的开始节点 startNode = new Attributes(); //决策结果变量的值 resultList = new ArrayList<String>(); //结果属性节点 TreeNode resultNode = null; //决策树 str = id3Name+"决策树:\r\n"; } /** * 读数据 */ private void readData(String path) { Map<String, String> dataMap; Map<String,Attributes> attrMap; TreeNode treeNode; int num; //创建读取properties文件的对象 Properties pro = new Properties(); try { //为了读取中文字符,将读取文件的类型改为字符流读取 InputStream inputStream = new FileInputStream(path); BufferedReader bf = new BufferedReader(new InputStreamReader(inputStream)); //加载数据文件 pro.load(bf); //读取数据总个数 num = Integer.parseInt(pro.getProperty("datanum")); //读取属性及属性值 String attribute = pro.getProperty("nodeAndAttribute"); 
//将每个属性分开,用数组存,遍历每个属性,再把每个属性的属性值分开,存到treeList中 String[] attArray = attribute.split(","); for (int i = 0; i < attArray.length; i++) { treeNode = new TreeNode(); String[] temp = attArray[i].split(":"); String nodeName = temp[0]; String[] attr = temp[1].split("/"); treeNode.setNodeName(nodeName); attrMap = new HashMap<String, Attributes>(); Attributes attributes; for (int j = 0; j < attr.length; j++) { //Map<String, Integer> map = new HashMap<String, Integer>(); attributes = new Attributes(); //map.put(attr[j], 0); attributes.setAttrName(attr[j]); attrMap.put(attr[j], attributes); //存入结果变量的值,为最后的判断做铺垫 if(i == attArray.length-1){ resultList.add(attr[j]); } } treeNode.setAttributes(attrMap); treeList.add(treeNode); } //遍历数据集,将数据按行存入dataList中 for (int i = 1; i <= num; i++) { dataMap = new HashMap<String, String>(); String key = "D"+i; String[] colline = pro.getProperty(key).split(","); //System.out.println(key+"=="+colline.length); for (int j = 0; j < treeList.size(); j++) { //System.out.println(treeList.size()); dataMap.put(treeList.get(j).getNodeName(), colline[j]); } dataList.add(dataMap); } //得到结果属性的名字 resultNode = treeList.get(treeList.size()-1); // System.out.println("************************resultNode==" + resultNode + "***********************"); } catch (FileNotFoundException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } /** * 数据处理 * @param cdataList * @param ctreeList */ private List<List> dealData(List<Map<String, String>> dataList, List<TreeNode> treeList){ List<List> returnList= new ArrayList<List>(); int num = dataList.size(); /* * 统计数据集中每个属性的属性值个数 */ Map<String, Attributes> attrMap = new HashMap<String, Attributes>(); Map<String, Integer> resultMap; for (int i = 0; i < treeList.size(); i++) { for (int j = 0; j < dataList.size(); j++) { //获得当前数据集中当前列当前行的属性值 String key = dataList.get(j).get(treeList.get(i).getNodeName()); attrMap = 
treeList.get(i).getAttributes(); //System.out.println(attrMap.get(key)+"=="+key); //计算样本中对应的属性变量的个数 attrMap.get(key).setAttrNum(attrMap.get(key).getAttrNum()+1); //System.out.println("->"+attrMap.get(key)); //获得结果变量值 String result = dataList.get(j).get(treeList.get(treeList.size()-1).getNodeName()); resultMap = attrMap.get(key).getResultNum(); //如果包含这个结果变量,则数量上加1; 如果不包含,赋初值为1 if (resultMap.containsKey(result)) { resultMap.put(result, resultMap.get(result)+1); }else{ resultMap.put(result, 1); } } } /* * 计算熵 */ DecimalFormat df = new DecimalFormat("#.###"); for (int i = 0; i < treeList.size(); i++) { //遍历 Attributes //计算属性熵: gain double gain = 0.0; for (Map.Entry<String, Attributes> element : treeList.get(i).getAttributes().entrySet()) { Attributes attr = treeList.get(i).getAttributes().get(element.getKey()); Map<String, Integer> result = attr.getResultNum(); //遍历每个 Attributes 的 resultNum //计算属性值的熵 :h double h = 0.0; for (Map.Entry<String, Integer> element2 : result.entrySet()) { double resultNum = (double)result.get(element2.getKey()); double attrNum = (double)attr.getAttrNum(); resultNum = resultNum/attrNum; h -= (resultNum*(Math.log(resultNum)/Math.log((double)2))); h = Double.parseDouble(df.format(h)); attr.setH(h); //System.out.println("resultNum=========="+resultNum); } //System.out.println(" attr==>"+attr); gain += ((double)attr.getAttrNum()/num)*attr.getH(); gain = Double.parseDouble(df.format(gain)); //System.out.println("gain=="+gain); } treeList.get(i).setGain(gain); //System.out.println(" gain-->"+treeList.get(i)); } //将处理好的dataList和treeList放在returnList中返回 returnList.add(dataList); returnList.add(treeList); return returnList; // System.out.println("***************************************************+++++++↓"); // for (int i = 0; i < treeList.size(); i++) { // System.out.println(treeList.get(i)); // } // System.out.println(); // for (int i = 0; i < dataList.size(); i++) { // System.out.println(dataList.get(i)); // } // // 
System.out.println("================================================="+num+"条数据=="+treeList.size()+"个属性"); // System.out.println("***************************************************+++++++↑"); } /** * 构建决策树 * @param dataList * @param treeList */ @SuppressWarnings("unchecked") private void cusTree(List<Map<String, String>> dataList, List<TreeNode> treeList, Attributes cAttr){ List<List> curryList= new ArrayList<List>(); //处理数据 curryList = dealData(dataList, treeList); //从 curryList 中得到 dataList 和 treeList dataList = (List<Map<String, String>>)curryList.get(0); treeList = (List<TreeNode>)curryList.get(1); //判断当前处理的数据集中的决策结果,若决策结果相同的个数等于总的当前处理的数据集的条数,则遍历结束 //将当前的决策结果放入当前判断的属性值的后边 //返回到调用这个函数的父函数 for (TreeNode treeNode : treeList) { if (treeNode.getNodeName().equals(resultNode.getNodeName())) { for (String attr : resultList) { if (treeNode.getAttributes().get(attr).getAttrNum() == dataList.size()) { cAttr.setLeafName(attr); return; } } } } // System.out.println("=_=_=_=_=_=_=datalist==="+dataList); // System.out.println("=_=_=_=_=_=_=treelist==="+treeList); //寻找最优解 //得到根节点 TreeNode rootNode = treeList.get(0); for (TreeNode treeNode : treeList) { if(!treeNode.getNodeName().equals(treeList.get(treeList.size()-1).getNodeName())){ if(treeNode.getGain() < rootNode.getGain()){ rootNode = treeNode; } } } // System.out.println("*********↓↓↓↓↓↓↓↓***********当前根节点为:"+rootNode.getNodeName()+"***********↓↓↓↓↓↓↓↓*********"); cAttr.setNextNode(rootNode); //对当前根节点的属性进行遍历,寻找下一个节点 //节点名 String nodeName = rootNode.getNodeName(); //属性名 String attrName = ""; //属性节点 Attributes attr = new Attributes(); //当前节点的属性值集合 Map<String, Attributes> attrMap = rootNode.getAttributes(); //遍历节点的每个属性值 for (Map.Entry<String, Attributes> entry : attrMap.entrySet()) { attr = attrMap.get(entry.getKey()); attrName = attr.getAttrName(); // System.out.println("*****************attrName========"+attrName+"******************"); //得到新的data集合对象 List<Map<String, String>> newDataList = new 
ArrayList<Map<String,String>>(); Map<String, String> newMap = new HashMap<String, String>(); //String attrName = rootNode.getAttributes().get("Sunny").getAttrName(); newMap.clear(); //删除dataList中已处理过的节点数据 //遍历dataList for (Map<String, String> map : dataList) { if(map.containsKey(nodeName)){ if(map.get(nodeName).equals(attrName)){ newMap = new HashMap<String, String>(); for (Map.Entry<String, String> m : map.entrySet()) { //如果该节点不是已处理过的节点 if(!m.getKey().equals(nodeName)){ //得到新的节点 newMap.put(m.getKey(), map.get(m.getKey())); } } //将新的节点存入newDataList中 newDataList.add(newMap); } } } // System.out.println("↓↓↓↓↓↓*******************新的data集合:*******************↓↓↓↓↓↓"); // for (Map<String, String> map : newDataList) { // System.out.println(map); // } //获得新的tree集合对象,而且值为初值 List<TreeNode> newTreeList = new ArrayList<TreeNode>(); //将treeList中的数据清空 clearTree(treeList); //删除treeList中已处理过的节点 for (TreeNode treeNode : treeList) { if(!treeNode.getNodeName().equals(nodeName)){ newTreeList.add(treeNode); } } // System.out.println("↓↓↓↓↓↓*******************新的tree集合:*******************↓↓↓↓↓↓"); // for (TreeNode treeNode : newTreeList) { // System.out.println(treeNode); // } //递归调用当前函数,继续找节点 cusTree(newDataList, newTreeList,attr); } } /** * 输出决策树 * @param attr */ private void printTree(Attributes attr, String ceil) { String nodeName = attr.getNextNode().getNodeName(); Map<String, Attributes> attrMap = attr.getNextNode().getAttributes(); str += ceil+"----"+nodeName+"\r\n"; for (Map.Entry<String, Attributes> nextAttr : attrMap.entrySet()) { //如果当前属性值没有下一个节点,则将当前属性值的名称及决策结果输出 if(attrMap.get(nextAttr.getKey()).getNextNode() == null){ str += ceil+"-------"+attrMap.get(nextAttr.getKey()).getAttrName()+"\r\n"; str += ceil+"----------"+attrMap.get(nextAttr.getKey()).getLeafName()+"\r\n"; }else{ str += ceil+"-------"+attrMap.get(nextAttr.getKey()).getAttrName()+"\r\n"; printTree(attrMap.get(nextAttr.getKey()),"------"); } } } /** * 打印决策树到txt文本 * @param path */ private void 
printTreetoTxt(String path){ if(path == null || path.equals("")) return; File file = new File(path); File folder = file.getParentFile(); FileWriter fw; try { if(!folder.exists()){ folder.mkdirs(); file.createNewFile(); } fw = new FileWriter(file); fw.write(str); fw.flush(); fw.close(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } /** * 还原初始数据 * @param treeList */ private void clearTree(List<TreeNode> treeList){ for (TreeNode treeNode : treeList) { Map<String, Attributes> map = treeNode.getAttributes(); for (Map.Entry<String, Attributes> entry : map.entrySet()) { Attributes attr = map.get(entry.getKey()); attr.setAttrNum(0); attr.setH(0); Map<String, Integer> map2 = attr.getResultNum(); map2.clear(); } treeNode.setGain(0); } } }
主函数:ID3Main.java
package com.id3.node;

/**
 * Command-line entry point for the ID3 decision-tree builder.
 *
 * <p>Usage: {@code ID3Main [treeName [dataFile [outputFile]]]}. Every argument
 * that is left out falls back to the placeholder text this class has always
 * used, so running without arguments behaves as before.
 */
public class ID3Main {

    /**
     * Builds and prints the decision tree.
     *
     * @param args optional tree name, data file path and output file path, in that order
     */
    public static void main(String[] args) {
        String treeName = argumentOrDefault(args, 0, "决策树名");
        String dataPath = argumentOrDefault(args, 1, "数据文件地址");
        String outputPath = argumentOrDefault(args, 2, "输出文件地址");

        ID3Alogo id3Alogo = new ID3Alogo();
        id3Alogo.ID3(treeName, dataPath, outputPath);
    }

    /** Returns {@code args[index]} when present, otherwise {@code fallback}. */
    private static String argumentOrDefault(String[] args, int index, String fallback) {
        if (args != null && index < args.length) {
            return args[index];
        }
        return fallback;
    }
}