基于sklearn 实现决策树(含最简代码,复杂源码:预测带不带眼镜)

最简代码:

1 #简单的决策树分类
2 from sklearn import tree
3 features = [[300,2],[450,2],[200,8],[150,9]]
4 labels = ['apple','apple','orange','orange']
5 clf = tree.DecisionTreeClassifier()
6 clf = clf.fit(features,labels)
7 print(clf.predict([[400,6]]))

预测代码:

数据集下载地址

代码:

复制代码
 1 # -*- coding: UTF-8 -*-
 2 from sklearn.preprocessing import LabelEncoder, OneHotEncoder
 3 from sklearn.externals.six import StringIO
 4 from sklearn import tree
 5 import pandas as pd
 6 import numpy as np
 7 import pydotplus
 8 
 9 if __name__ == '__main__':
10     with open('data\lenses.txt', 'r') as fr:                                        #加载文件
11         lenses = [inst.strip().split('\t') for inst in fr.readlines()]        #处理文件
12     lenses_target = []                                                        #提取每组数据的类别,保存在列表里
13     for each in lenses:
14         lenses_target.append(each[-1])
15 
16     lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']            #特征标签
17     lenses_list = []                                                        #保存lenses数据的临时列表
18     lenses_dict = {}                                                        #保存lenses数据的字典,用于生成pandas
19     for each_label in lensesLabels:                                            #提取信息,生成字典
20         for each in lenses:
21             lenses_list.append(each[lensesLabels.index(each_label)])
22         lenses_dict[each_label] = lenses_list
23         lenses_list = []
24     # print(lenses_dict)                                                        #打印字典信息
25     lenses_pd = pd.DataFrame(lenses_dict)                                    #生成pandas.DataFrame
26     print(lenses_pd)                                                        #打印pandas.DataFrame
27     le = LabelEncoder()                                                        #创建LabelEncoder()对象,用于序列化
28     for col in lenses_pd.columns:                                            #序列化
29         lenses_pd[col] = le.fit_transform(lenses_pd[col])
30     print(lenses_pd)                                                        #打印编码信息
31 
32     clf = tree.DecisionTreeClassifier(max_depth = 4)                        #创建DecisionTreeClassifier()类
33     clf = clf.fit(lenses_pd.values.tolist(), lenses_target)                    #使用数据,构建决策树
34     print(lenses_target)
35     print(clf.predict([[1,1,1,0]]))                    #预测
预测眼镜
复制代码

 

posted @   博二爷  阅读(788)  评论(0编辑  收藏  举报
编辑推荐:
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· C#/.NET/.NET Core技术前沿周刊 | 第 29 期(2025年3.1-3.9)
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异
历史上的今天:
2019-03-03 课堂测试总结
点击右上角即可分享
微信分享提示