决策树 ID3 算法的 Java 实现(基本适用于所有 ID3 场景)
已知:流感训练数据集,预定义两个类别;
求:用ID3算法建立流感的属性描述决策树
流感训练数据集

| No. | 头痛 | 肌肉痛 | 体温 | 患流感 |
|-----|------|--------|------|--------|
| 1 | 是(1) | 是(1) | 正常(0) | 否(0) |
| 2 | 是(1) | 是(1) | 高(1) | 是(1) |
| 3 | 是(1) | 是(1) | 很高(2) | 是(1) |
| 4 | 否(0) | 是(1) | 正常(0) | 否(0) |
| 5 | 否(0) | 否(0) | 高(1) | 否(0) |
| 6 | 否(0) | 是(1) | 很高(2) | 是(1) |
| 7 | 是(1) | 否(0) | 高(1) | 是(1) |
原理分析:
在决策树的每一个非叶子结点划分之前,先计算每一个属性所带来的信息增益,选择信息增益最大的属性来划分。因为信息增益越大,该属性区分样本的能力就越强,也就越具有代表性。
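信息熵计算:$H(S) = -\sum_{i=1}^{m} p_i \log_2 p_i$,其中 $p_i$ 为样本集 $S$ 中第 $i$ 个类别所占的比例(约定 $0 \cdot \log_2 0 = 0$)。
信息增益:$\mathrm{Gain}(S, A) = H(S) - \sum_{v \in \mathrm{Values}(A)} \frac{|S_v|}{|S|} H(S_v)$,其中 $S_v$ 为 $S$ 中属性 $A$ 取值为 $v$ 的样本子集。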
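计算的结果(原文为手写草稿,字丑别喷):整体熵 $H(S) = -\frac{4}{7}\log_2\frac{4}{7} - \frac{3}{7}\log_2\frac{3}{7} \approx 0.985$;$\mathrm{Gain}(S, 头痛) \approx 0.128$,$\mathrm{Gain}(S, 肌肉痛) \approx 0.006$,$\mathrm{Gain}(S, 体温) \approx 0.592$,因此根结点按“体温”划分。在“体温=高”的子集(样本 2、5、7)中,$\mathrm{Gain}(头痛) \approx 0.918$,$\mathrm{Gain}(肌肉痛) \approx 0.252$,故继续按“头痛”划分;其余子集类别已纯,直接成为叶子结点。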
--------------------------------------------------------------------------------------------------------------------------------------------
*************************************************************************************************************
************************实现*********************************************
package ID3Tree; import java.util.Comparator;; @SuppressWarnings("rawtypes") public class Comparisons implements Comparator { public int compare(Object a, Object b) throws ClassCastException{ String str1 = (String)a; String str2 = (String)b; return str1.compareTo(str2); } }
package ID3Tree; public class Entropy { //信息熵 public static double getEntropy(int x, int total) { if (x == 0) { return 0; } double x_pi = getShang(x,total); return -(x_pi*Logs(x_pi)); } public static double Logs(double y) { return Math.log(y) / Math.log(2); } public static double getShang(int x, int total) { return x * Double.parseDouble("1.0") / total; } }
package ID3Tree; public class TreeNode { //父节点 TreeNode parent; //指向父节点的属性 String parentAttribute; String nodeName; String[] attributes; TreeNode[] childNodes; }
package ID3Tree; import java.util.*; public class UtilID3 { TreeNode root; private boolean[] flag; //训练集 private Object[] trainArrays; //节点索引 private int nodeIndex; public static void main(String[] args) { //初始化训练集数组 Object[] arrays = new Object[]{ new String[]{"是","是","正常","否"}, new String[]{"是","是","高","是"}, new String[]{"是","是","很高","是"}, new String[]{"否","是","正常","否"}, new String[]{"否","否","高","否"}, new String[]{"否","是","很高","是"}, new String[]{"是","否","高","是"}}; UtilID3 ID3Tree = new UtilID3(); ID3Tree.create(arrays, 3); } //创建 public void create(Object[] arrays, int index) { this.trainArrays = arrays; initial(arrays, index); createDTree(arrays); printDTree(root); } //初始化 public void initial(Object[] dataArray, int index) { this.nodeIndex = index; //数据初始化 this.flag = new boolean[((String[])dataArray[0]).length]; for (int i = 0; i<this.flag.length; i++) { if (i == index) { this.flag[i] = true; } else { this.flag[i] = false; } } } //创建决策树 public void createDTree(Object[] arrays) { Object[] ob = getMaxGain(arrays); if (this.root == null) { this.root = new TreeNode(); root.parent = null; root.parentAttribute = null; root.attributes = getAttributes(((Integer)ob[1]).intValue()); root.nodeName = getNodeName(((Integer)ob[1]).intValue()); root.childNodes = new TreeNode[root.attributes.length]; insert(arrays, root); } } //插入决策树 public void insert(Object[] arrays, TreeNode parentNode) { String[] attributes = parentNode.attributes; for (int i = 0; i < attributes.length; i++) { Object[] Arrays = pickUpAndCreateArray(arrays, attributes[i],getNodeIndex(parentNode.nodeName)); Object[] info = getMaxGain(Arrays); double gain = ((Double)info[0]).doubleValue(); if (gain != 0) { int index = ((Integer)info[1]).intValue(); TreeNode currentNode = new TreeNode(); currentNode.parent = parentNode; currentNode.parentAttribute = attributes[i]; currentNode.attributes = getAttributes(index); currentNode.nodeName = getNodeName(index); currentNode.childNodes = new 
TreeNode[currentNode.attributes.length]; parentNode.childNodes[i] = currentNode; insert(Arrays, currentNode); } else { TreeNode leafNode = new TreeNode(); leafNode.parent = parentNode; leafNode.parentAttribute = attributes[i]; leafNode.attributes = new String[0]; leafNode.nodeName = getLeafNodeName(Arrays); leafNode.childNodes = new TreeNode[0]; parentNode.childNodes[i] = leafNode; } } } //输出 public void printDTree(TreeNode node) { System.out.println(node.nodeName); TreeNode[] childs = node.childNodes; for (int i = 0; i < childs.length; i++) { if (childs[i] != null) { System.out.println("如果:"+childs[i].parentAttribute); printDTree(childs[i]); } } } //剪取数组 public Object[] pickUpAndCreateArray(Object[] arrays, String attribute, int index) { List<String[]> list = new ArrayList<String[]>(); for (int i = 0; i < arrays.length; i++) { String[] strs = (String[])arrays[i]; if (strs[index].equals(attribute)) { list.add(strs); } } return list.toArray(); } //取得节点名 public String getNodeName(int index) { String[] strs = new String[]{"头痛","肌肉痛","体温","患流感"}; for (int i = 0; i < strs.length; i++) { if (i == index) { return strs[i]; } } return null; } //取得叶子节点名 public String getLeafNodeName(Object[] arrays) { if (arrays != null && arrays.length > 0) { String[] strs = (String[])arrays[0]; return strs[nodeIndex]; } return null; } //取得节点索引 public int getNodeIndex(String name) { String[] strs = new String[]{"头痛","肌肉痛","体温","患流感"}; for (int i = 0; i < strs.length; i++) { if (name.equals(strs[i])) { return i; } } return -1; } //得到最大信息增益 public Object[] getMaxGain(Object[] arrays) { Object[] result = new Object[2]; double gain = 0; int index = -1; for (int i = 0; i<this.flag.length; i++) { if (!this.flag[i]) { double value = gain(arrays, i); if (gain < value) { gain = value; index = i; } } } result[0] = gain; result[1] = index; if (index != -1) { this.flag[index] = true; } return result; } //取得属性数组 public String[] getAttributes(int index) { @SuppressWarnings("unchecked") TreeSet<String> 
set = new TreeSet<String>(new Comparisons()); for (int i = 0; i<this.trainArrays.length; i++) { String[] strs = (String[])this.trainArrays[i]; set.add(strs[index]); } String[] result = new String[set.size()]; return set.toArray(result); } //计算信息增益 public double gain(Object[] arrays, int index) { String[] playBalls = getAttributes(this.nodeIndex); int[] counts = new int[playBalls.length]; for (int i = 0; i<counts.length; i++) { counts[i] = 0; } for (int i = 0; i<arrays.length; i++) { String[] strs = (String[])arrays[i]; for (int j = 0; j<playBalls.length; j++) { if (strs[this.nodeIndex].equals(playBalls[j])) { counts[j]++; } } } double entropyS = 0; for (int i = 0;i <counts.length; i++) { entropyS = entropyS + Entropy.getEntropy(counts[i], arrays.length); } String[] attributes = getAttributes(index); double total = 0; for (int i = 0; i<attributes.length; i++) { total = total + entropy(arrays, index, attributes[i], arrays.length); } return entropyS - total; } public double entropy(Object[] arrays, int index, String attribute, int totals) { String[] playBalls = getAttributes(this.nodeIndex); int[] counts = new int[playBalls.length]; for (int i = 0; i < counts.length; i++) { counts[i] = 0; } for (int i = 0; i < arrays.length; i++) { String[] strs = (String[])arrays[i]; if (strs[index].equals(attribute)) { for (int k = 0; k<playBalls.length; k++) { if (strs[this.nodeIndex].equals(playBalls[k])) { counts[k]++; } } } } int total = 0; double entropy = 0; for (int i = 0; i < counts.length; i++) { total = total +counts[i]; } for (int i = 0; i < counts.length; i++) { entropy = entropy + Entropy.getEntropy(counts[i], total); } return Entropy.getShang(total, totals)*entropy; } }
作者:Honey_Badger —— 觉得这文章好,点一下左下角
出处:http://tk55.cnblogs.com/
posted on 2016-12-28 22:57 Honey_Badger 阅读(9074) 评论(2) 编辑 收藏 举报