以下程序是我练习时写的，不一定正确，也没有做存储优化。有问题请留言交流。转载请附上原文链接。
当前的属性为:age income student credit_rating
当前的数据集如下（每行一条样本，字段以空格分隔，前四列依次对应上述属性，最后一列是目标值 TARGET_VALUE）：
---------------------------------
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged high yes fair yes
senior medium no excellent no
---------------------------------
C4.5 决策树构建类
package C45Test;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Builds a C4.5 decision tree from a table of string-valued rows.
 *
 * <p>Each row holds one value per attribute followed by the target value in its last column.
 * Attribute scoring (gain ratio) is delegated to {@link InfoGain}.
 */
public class DecisionTree {

    /**
     * Recursively builds the (sub)tree for {@code data}.
     *
     * @param data rows whose last column is the target value; the caller's rows are not modified
     *     (child subsets are copies obtained from {@link InfoGain#getData4Value})
     * @param attributeList names of the attributes still available, aligned with the leading
     *     columns of each row
     * @return the root of the (sub)tree. A leaf has node name {@code "leafNode"} and the predicted
     *     target in {@code targetFunValue}; an inner node stores the split attribute name in
     *     {@code attributeValue} and one child per observed value, with the value recorded in
     *     {@code pathName} at the same index
     */
    public TreeNode createDT(List<ArrayList<String>> data, List<String> attributeList) {
        printState(data, attributeList);

        TreeNode node = new TreeNode();
        if (data.isEmpty()) {
            // Nothing to learn from; InfoGain.IsPure cannot be asked about an empty target list.
            node.setNodeName("leafNode");
            return node;
        }

        String result = InfoGain.IsPure(InfoGain.getTarget(data));
        if (result != null) {
            node.setNodeName("leafNode");
            node.setTargetFunValue(result);
            return node;
        }

        // The subset is impure from here on: whenever no further split is possible, predict
        // the majority class instead of the (null) purity result.
        String majority = majorityTarget(data);
        if (attributeList.isEmpty()) {
            node.setNodeName("leafNode");
            node.setTargetFunValue(majority);
            return node;
        }

        InfoGain gain = new InfoGain(data, attributeList);
        int attrIndex = selectBestAttribute(gain, attributeList.size());
        if (attrIndex < 0) {
            // No attribute has a positive gain ratio (e.g. conflicting rows that agree on every
            // remaining attribute); previously this indexed attributeList with -1.
            node.setNodeName("leafNode");
            node.setTargetFunValue(majority);
            return node;
        }

        String splitAttr = attributeList.get(attrIndex);
        System.out.println("选择出的最大增益率属性为: " + splitAttr);
        node.setAttributeValue(splitAttr);

        Map<String, Long> attrValueMap = gain.getAttributeValue(attrIndex);
        for (String value : attrValueMap.keySet()) {
            List<ArrayList<String>> resultData = gain.getData4Value(value, attrIndex);
            System.out.println("当前为" + splitAttr + "的" + value + "分支。");
            TreeNode child;
            if (resultData.isEmpty()) {
                // Defensive: values come from the data itself, so a branch should never be empty.
                child = new TreeNode();
                child.setNodeName(splitAttr);
                child.setTargetFunValue(majority);
                child.setAttributeValue(value);
            } else {
                // Drop the consumed attribute column from the (copied) child rows and names.
                for (ArrayList<String> row : resultData) {
                    row.remove(attrIndex);
                }
                List<String> resultAttr = new ArrayList<String>(attributeList);
                resultAttr.remove(attrIndex);
                child = createDT(resultData, resultAttr);
            }
            node.getChildTreeNode().add(child);
            node.getPathName().add(value);
        }
        return node;
    }

    /**
     * Returns the index of the attribute with the strictly largest positive gain ratio (first one
     * wins on ties), or -1 if no attribute has a positive gain ratio.
     *
     * <p>Attributes with at most one distinct value are skipped: their split information is
     * zero, so the gain ratio is undefined and {@code MathUtils.div} would throw on the zero
     * divisor.
     */
    private static int selectBestAttribute(InfoGain gain, int attributeCount) {
        double maxGain = 0.0;
        int best = -1;
        for (int i = 0; i < attributeCount; i++) {
            if (gain.getAttributeValue(i).size() <= 1) {
                continue;
            }
            double ratio = gain.getGainRatio(i);
            if (ratio > maxGain) {
                maxGain = ratio;
                best = i;
            }
        }
        return best;
    }

    /**
     * Returns the most frequent target (last-column) value in {@code data}; ties go to the value
     * encountered first. Returns null for empty data.
     */
    private static String majorityTarget(List<ArrayList<String>> data) {
        Map<String, Integer> counts = new LinkedHashMap<String, Integer>();
        for (ArrayList<String> row : data) {
            String target = row.get(row.size() - 1);
            Integer count = counts.get(target);
            counts.put(target, count == null ? 1 : count + 1);
        }
        String best = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > bestCount) {
                best = entry.getKey();
                bestCount = entry.getValue();
            }
        }
        return best;
    }

    /** Dumps the current rows and attribute names to stdout as a trace of the recursion. */
    private static void printState(List<ArrayList<String>> data, List<String> attributeList) {
        System.out.println("当前的DATA为");
        for (ArrayList<String> row : data) {
            for (String cell : row) {
                System.out.print(cell + " ");
            }
            System.out.println();
        }
        System.out.println("---------------------------------");
        System.out.println("当前的ATTR为");
        for (String attr : attributeList) {
            System.out.print(attr + " ");
        }
        System.out.println();
        System.out.println("---------------------------------");
    }

    /**
     * A node of the decision tree.
     *
     * <p>Inner nodes store the split attribute name in {@code attributeValue}; their children and
     * the attribute value leading to each child are kept in the parallel lists
     * {@code childTreeNode} and {@code pathName}. Leaves store the predicted target value in
     * {@code targetFunValue}.
     */
    class TreeNode {
        private String attributeValue;
        private List<TreeNode> childTreeNode;
        private List<String> pathName;
        private String targetFunValue;
        private String nodeName;

        public TreeNode(String nodeName) {
            this.nodeName = nodeName;
            this.childTreeNode = new ArrayList<TreeNode>();
            this.pathName = new ArrayList<String>();
        }

        public TreeNode() {
            this.childTreeNode = new ArrayList<TreeNode>();
            this.pathName = new ArrayList<String>();
        }

        public String getAttributeValue() {
            return attributeValue;
        }

        public void setAttributeValue(String attributeValue) {
            this.attributeValue = attributeValue;
        }

        public List<TreeNode> getChildTreeNode() {
            return childTreeNode;
        }

        public void setChildTreeNode(List<TreeNode> childTreeNode) {
            this.childTreeNode = childTreeNode;
        }

        public String getTargetFunValue() {
            return targetFunValue;
        }

        public void setTargetFunValue(String targetFunValue) {
            this.targetFunValue = targetFunValue;
        }

        public String getNodeName() {
            return nodeName;
        }

        public void setNodeName(String nodeName) {
            this.nodeName = nodeName;
        }

        public List<String> getPathName() {
            return pathName;
        }

        public void setPathName(List<String> pathName) {
            this.pathName = pathName;
        }
    }
}
增益率计算类（取对数时以 e 为底，而不是 2。增益率是信息增益与分裂信息之比，换底只会让分子和分母同乘常数 1/ln2，因此不影响增益率，也不影响最终选出的属性）
package C45Test; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; //C 4.5 实现 public class InfoGain { private List<ArrayList<String>> data; private List<String> attribute; public InfoGain(List<ArrayList<String>> data,List<String> attribute){ this.data = new ArrayList<ArrayList<String>>(); for(int i=0;i<data.size();i++){ List<String> temp = data.get(i); ArrayList<String> t = new ArrayList<String>(); for(int j=0;j<temp.size();j++){ t.add(temp.get(j)); } this.data.add(t); } this.attribute = new ArrayList<String>(); for(int k=0;k<attribute.size();k++){ this.attribute.add(attribute.get(k)); } /*this.data = data; this.attribute = attribute;*/ } //获得熵 public double getEntropy(){ Map<String,Long> targetValueMap = getTargetValue(); Set<String> targetkey = targetValueMap.keySet(); double entropy = 0.0; for(String key : targetkey){ double p = MathUtils.div((double)targetValueMap.get(key), (double)data.size()); entropy += (-1) * p * Math.log(p); } return entropy; } //获得InfoA public double getInfoAttribute(int attributeIndex){ Map<String,Long> attributeValueMap = getAttributeValue(attributeIndex); double infoA = 0.0; for(Map.Entry<String, Long> entry : attributeValueMap.entrySet()){ int size = data.size(); double attributeP = MathUtils.div((double)entry.getValue() , (double) size); Map<String,Long> targetValueMap = getAttributeValueTargetValue(entry.getKey(),attributeIndex); long totalCount = 0L; for(Map.Entry<String, Long> entryValue :targetValueMap.entrySet()){ totalCount += entryValue.getValue(); } double valueSum = 0.0; for(Map.Entry<String, Long> entryTargetValue : targetValueMap.entrySet()){ double p = MathUtils.div((double)entryTargetValue.getValue(), (double)totalCount); valueSum += Math.log(p) * p; } infoA += (-1) * attributeP * valueSum; } return infoA; } //得到属性值在决策空间的比例 public Map<String,Long> getAttributeValueTargetValue(String 
attributeName,int attributeIndex){ Map<String,Long> targetValueMap = new HashMap<String,Long>(); Iterator<ArrayList<String>> iterator = data.iterator(); while(iterator.hasNext()){ List<String> tempList = iterator.next(); if(attributeName.equalsIgnoreCase(tempList.get(attributeIndex))){ int size = tempList.size(); String key = tempList.get(size - 1); Long value = targetValueMap.get(key); targetValueMap.put(key, value != null ? ++value :1L); } } return targetValueMap; } //得到属性在决策空间上的数量 public Map<String,Long> getAttributeValue(int attributeIndex){ Map<String,Long> attributeValueMap = new HashMap<String,Long>(); for(ArrayList<String> note : data){ String key = note.get(attributeIndex); Long value = attributeValueMap.get(key); attributeValueMap.put(key, value != null ? ++value :1L); } return attributeValueMap; } public List<ArrayList<String>> getData4Value(String attrValue,int attrIndex){ List<ArrayList<String>> resultData = new ArrayList<ArrayList<String>>(); Iterator<ArrayList<String>> iterator = data.iterator(); for(;iterator.hasNext();){ ArrayList<String> templist = iterator.next(); if(templist.get(attrIndex).equalsIgnoreCase(attrValue)){ ArrayList<String> temp = (ArrayList<String>) templist.clone(); resultData.add(temp); } } return resultData; } //获得增益率 public double getGainRatio(int attributeIndex){ return MathUtils.div(getGain(attributeIndex), getSplitInfo(attributeIndex)); } //获得增益量 public double getGain(int attributeIndex){ return getEntropy() - getInfoAttribute(attributeIndex); } //得到惩罚因子 public double getSplitInfo(int attributeIndex){ Map<String,Long> attributeValueMap = getAttributeValue(attributeIndex); double splitA = 0.0; for(Map.Entry<String, Long> entry : attributeValueMap.entrySet()){ int size = data.size(); double attributeP = MathUtils.div((double)entry.getValue() , (double) size); splitA += attributeP * Math.log(attributeP) * (-1); } return splitA; } //得到目标函数在当前集合范围内的离散的值 public Map<String,Long> getTargetValue(){ Map<String,Long> targetValueMap = 
new HashMap<String,Long>(); Iterator<ArrayList<String>> iterator = data.iterator(); while(iterator.hasNext()){ List<String> tempList = iterator.next(); String key = tempList.get(tempList.size() - 1); Long value = targetValueMap.get(key); targetValueMap.put(key, value != null ? ++value : 1L); } return targetValueMap; } //获得TARGET值 public static List<String> getTarget(List<ArrayList<String>> data){ List<String> list = new ArrayList<String>(); for(ArrayList<String> temp : data){ int index = temp.size() -1; String value = temp.get(index); list.add(value); } return list; } //判断当前纯度是否100% public static String IsPure(List<String> list){ Set<String> set = new HashSet<String>(); for(String name :list){ set.add(name); } if(set.size() > 1) return null; Iterator<String> iterator = set.iterator(); return iterator.next(); } }
测试类：通过 configAttribute() 和 configData() 把上面的属性和数据集分别放到两个 List 中。
package C45Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import C45Test.DecisionTree.TreeNode;

/**
 * Demo driver: builds a C4.5 tree over the 12-row sample dataset from the article.
 *
 * <p>{@code configData()} and {@code configAttribute()} were called but never defined, so the
 * class did not compile. They now load the sample data embedded below.
 */
public class MainC45 {
    private static final List<ArrayList<String>> dataList = new ArrayList<ArrayList<String>>();
    private static final List<String> attributeList = new ArrayList<String>();

    /** Attribute names, in the same order as the leading columns of each sample row. */
    private static final String[] ATTRIBUTES = {"age", "income", "student", "credit_rating"};

    /** Sample rows, space-separated; the last column is the target value. */
    private static final String[] RAW_ROWS = {
        "youth high no fair no",
        "youth high no excellent no",
        "middle_aged high no fair yes",
        "senior low yes fair yes",
        "senior low yes excellent no",
        "middle_aged low yes excellent yes",
        "youth medium no fair no",
        "youth low yes fair yes",
        "senior medium yes fair yes",
        "youth medium yes excellent yes",
        "middle_aged high yes fair yes",
        "senior medium no excellent no",
    };

    public static void main(String[] args) {
        DecisionTree dt = new DecisionTree();
        // Root of the built tree; createDT itself traces the construction to stdout.
        TreeNode node = dt.createDT(configData(), configAttribute());
        System.out.println();
    }

    /**
     * Loads {@link #RAW_ROWS} into {@code dataList}, replacing any previous contents so repeated
     * calls do not duplicate rows, and returns it.
     */
    private static List<ArrayList<String>> configData() {
        dataList.clear();
        for (String line : RAW_ROWS) {
            dataList.add(new ArrayList<String>(Arrays.asList(line.trim().split("\\s+"))));
        }
        return dataList;
    }

    /** Loads {@link #ATTRIBUTES} into {@code attributeList}, replacing any previous contents. */
    private static List<String> configAttribute() {
        attributeList.clear();
        attributeList.addAll(Arrays.asList(ATTRIBUTES));
        return attributeList;
    }
}
大数运算工具类（基于 BigDecimal 的十进制精确加、减、乘、除）
package C45Test; import java.math.BigDecimal; public abstract class MathUtils { //默认余数长度 private static final int DIV_SCALE = 10; //受限于DOUBLE长度 public static double add(double value1,double value2){ BigDecimal big1 = new BigDecimal(String.valueOf(value1)); BigDecimal big2 = new BigDecimal(String.valueOf(value2)); return big1.add(big2).doubleValue(); } //大数加法 public static double add(String value1,String value2){ BigDecimal big1 = new BigDecimal(value1); BigDecimal big2 = new BigDecimal(value2); return big1.add(big2).doubleValue(); } public static double div(double value1,double value2){ BigDecimal big1 = new BigDecimal(String.valueOf(value1)); BigDecimal big2 = new BigDecimal(String.valueOf(value2)); return big1.divide(big2,DIV_SCALE,BigDecimal.ROUND_HALF_UP).doubleValue(); } public static double mul(double value1,double value2){ BigDecimal big1 = new BigDecimal(String.valueOf(value1)); BigDecimal big2 = new BigDecimal(String.valueOf(value2)); return big1.multiply(big2).doubleValue(); } public static double sub(double value1,double value2){ BigDecimal big1 = new BigDecimal(String.valueOf(value1)); BigDecimal big2 = new BigDecimal(String.valueOf(value2)); return big1.subtract(big2).doubleValue(); } public static double returnMax(double value1, double value2) { BigDecimal big1 = new BigDecimal(value1); BigDecimal big2 = new BigDecimal(value2); return big1.max(big2).doubleValue(); } }