python实现决策树

1.决策树的简介

  http://www.cnblogs.com/lufangtao/archive/2013/05/30/3103588.html

2.决策是实现的伪代码

“读入训练数据”
“找出每个属性的可能取值”

“递归调用建立决策树的函数”
    “para:节点,剩余样例,剩余属性”

    if “剩余属性个数为0"
        return most_of_result
    else if “剩余样例都属于同一个分类(yes/no)"    
        return yes/no
    else:
           ”对于每一个剩余属性,计算该属性的熵增“,并找到熵增最大的对应的属性,即为最佳分类属性”
        “按照最佳分类属性分类,对于每个分支,递归调用建立函数,最终得到整个决策树”

3.python数据结构设计

  1.数据集:用于存储二维的训练数据training_data

    二维的list数组,对于二维的list要取得某一列的数据,可以用zip(*dataset)[num]

  2.属性集合:用于存储属性的名称attri_name_set

    一维的list

  3.属性的可能取值:存储各个属性的可能取值状态

    dict+set:dict的key是属性的名称,value是set类型,这样可以保证不会有重复
    新建set类型:attri[i] = set()

  4.树的节点定义

  class Dtree_node(object):
      def __init__(self):
          self.attriname = None
          self.sub_node = {} #子节点为dict类型
  子节点的类型为dict,key是属性的不同取值,value是对应的子节点

 

 

4.code

# -*- coding: utf-8 -*-
from __future__ import division
import  math

__author__ = 'Jiayin'
#date:2016-3-28
#决策树的实现,从test.txt中读入训练数据,
#全局变量 training_data = [] #数据集(二维list表) attri = {} #属性集(dict+set) attri_name_set = [] class Dtree_node(object): def __init__(self): self.attriname = None self.sub_node = {} #子节点为dict类型 root = Dtree_node() #输入数据 def get_input(): #属性集合 属性是dict结构,key为属性名(str),value是该属性可以取到的值类型为set #第一个属性通常为编号,最后一个属性通常为决策结果,取值只有yes/no global attri global attri_name_set file_read = open("test.txt") line = file_read.readline().split() attri_name_set = line[:] #print line for i in line: attri[i] = set() line = file_read.readline().split() #读入数据,并计算每个属性的可能取值 while line: training_data.append(line) for i in range(1,len(line)-1): attri[attri_name_set[i]].add(line[i]) line = file_read.readline().split() #取most_of _result def getmost(dataset_result): p = 0 n = 0 for i in dataset_result: if i == 'yes': p+=1 else: n+=1 return 'yes' if p>n else 'no' #计算熵 def cal_entropy(dataset_result): num_yes = 0 num_no = 0 for i in dataset_result: if i == 'yes': num_yes +=1 else: num_no += 1 if num_no == 0 or num_yes == 0: return 0 total_num = num_no +num_yes per_yes = num_yes/total_num per_no = num_no/total_num return -per_yes*math.log(per_yes,2)-per_no*math.log(per_no,2) #计算某个属性的熵增 #参数 :数据集和属性名,初始熵 def cal_incr_entr_attri(data_set,attriname,init_entropy): global attri global attri_name_set incr_entr = init_entropy attri_index = attri_name_set.index(attriname) #将该属性的不同取值提取出来,并分别计算熵,求出熵增 for i in attri[attriname]: #new_data = data_set[:] new_data = filter(lambda x: True if x[attri_index] == i else False ,data_set) if len(new_data)==0: continue num = cal_entropy(zip(*new_data)[-1]) incr_entr -= len(new_data)/len(data_set)*num return incr_entr #判断是否剩余数据集都是一个结果 def if_all_label(dataset_result, result): #result = dataset_result[0] for i in range(0,len(dataset_result)): if dataset_result[i] <> result: break return False if dataset_result[i]<>result else True #建立决策树 #参数:root:节点 dataset:剩下的数据集 attriset:剩下的属性集 def create_Dtree(root_node , data_set , attri_set): global attri global attri_name_set ''' #如果当前数据集为空,应该返回上一层的most_of_result,此处要修改 if len(data_set)==0: return None''' #考虑如果剩余属性集为空,则返回most_of_result if len(attri_set) == 0: print zip(*data_set) root_node.attriname = getmost(zip(*data_set)[-1]) #zip(*dataset)[-1]表示取出最后一列,也就是yes/no那一列 return None #考虑如果剩余的数据集都是一个结果的话,返回这个结果 elif if_all_label(zip(*data_set)[-1],'yes'): root_node.attriname = 'yes' return None elif if_all_label(zip(*data_set)[-1],'no'): root_node.attriname = 'no' return None #print zip(*data_set) init_entropy = cal_entropy(zip(*data_set)[-1])#计算初始熵 max_entropy = 0 for i in attri_set: entropy = cal_incr_entr_attri(data_set,i,init_entropy) if entropy > max_entropy: max_entropy = entropy best_attri = i new_attri = attri_set[:] root_node.attriname = best_attri attri_index = attri_name_set.index(best_attri) for attri_value in attri[best_attri]: #new_data = data_set[:] new_data = filter(lambda x: True if x[attri_index] == attri_value else False ,data_set) root_node.sub_node[attri_value] = Dtree_node() #如果该分支下面的数据集个数为0,则采用父节点的most_of_result if len(new_data)==0: root_node.sub_node[attri_value].attriname = getmost(zip(*data_set)[-1]) else: create_Dtree(root_node.sub_node[attri_value],new_data,new_attri) def print_Dtree(Root_node,layer): print Root_node.attriname count = 1 if len(Root_node.sub_node) > 0: for sub in Root_node.sub_node.keys(): for i in range(layer): print "| ", print "|----%10s---"%sub, assert isinstance(layer, object) print_Dtree(Root_node.sub_node[sub] , layer+1) #count += 1 def main(): global root global attri_name_set get_input()#输入 attri_set = attri_name_set[1:-1]#提取出要分类的属性 create_Dtree(root,training_data,attri_set)#创建决策树 print_Dtree(root,0)#打印决策树 main()

 

posted @ 2016-03-30 16:12  928pjy  阅读(1721)  评论(0编辑  收藏  举报