Python 机器学习 数据集分布可视化
Python 的机器学习项目中,可视化是理解数据、模型和预测结果的重要工具。通过可视化可以观察数据集的分布情况,了解数据的特征和规律,可以评估模型的性能,发现模型的优缺点,分析预测结果,解释模型的预测过程。可视化数据集的分布和预测结果是整个过程中一个重要的步骤。通常可视化可以用Seaborn实现,它是基于 Matplotlib 的高级绘图库,提供了一些更高级的绘图功能。
参考文档:
1、加载数据集
load_iris()
是scikit-learn库中的一个函数,用于加载一个著名的数据集,即鸢尾花(Iris)数据集。数据集通常用于机器学习和统计分类技术的示例、测试和实验。鸢尾花数据集包含了三种鸢尾花(Iris setosa、Iris virginica和Iris versicolor)的150个样本。每个样本有四个特征:萼片长度、萼片宽度、花瓣长度和花瓣宽度,这些特征的单位都是厘米。目标变量是花的种类。数据集常用于分类算法的教学和测试,特别是对于新手来说,它是理解机器学习概念的一个很好的入门数据集。可以用于各种分类算法,包括最简单的如K-近邻(KNN)算法,以及更复杂的如支持向量机(SVM)和神经网络。
from sklearn.datasets import load_iris # 加载数据集 iris = load_iris() # 特征矩阵,iris.data包含了150个样本的四个特征值 x = iris.data print(x) # 目标向量,iris.target包含了相应的种类标签(0, 1, 2分别代表三种不同的鸢尾花)。 y = iris.target print(y) # 特征名称 feature_names = iris.feature_names print(feature_names) # 目标名称(花的种类) target_names = iris.target_names print(target_names)
2、seaborn.lmplot()的使用
Seaborn 的 lmplot()
函数是用于绘制线性回归模型的强大工具,它结合了 regplot() 和 FacetGrid。这个函数适用于绘制数据集中变量间线性关系的图形,尤其是探索两个连续变量(或一个连续和一个分类变量)之间的关系。lmplot()
是一个功能强大的工具,适用于探索和呈现变量间的线性关系,特别是在数据集包含分类变量时。常用参数如下,
参数 |
描述 |
x |
数据框架中的变量名称,将在 x 轴上绘制。 |
y |
数据框架中的变量名称,将在 y 轴上绘制。 |
hue |
数据框架中的分类变量,不同类别以不同颜色显示。 |
data |
数据源,通常是 Pandas 的 DataFrame。 |
palette |
设置不同类别的颜色。 |
col |
用于在不同列展示数据框架中一个额外分类变量的不同级别。 |
row |
用于在不同行展示数据框架中一个额外分类变量的不同级别。 |
markers |
指定不同类别的数据点的标记类型。 |
fit_reg |
布尔值,控制是否绘制回归模型 (对于 KNN,通常为 False)。 |
scatter_kws |
传递额外的关键字参数到底层 Matplotlib 函数, 控制散点的样式。 |
line_kws |
传递额外的关键字参数到底层 Matplotlib 函数, 控制线条的样式。 |
height |
每个面板的高度大小。 |
aspect |
每个面板的宽高比。 |
使用代码:
import seaborn as sns import matplotlib.pyplot as plt import pandas as pd from sklearn.datasets import load_iris # 加载鸢尾花数据集 #iris = sns.load_dataset('iris') #加载报错,可以直接使用sklearn.datasets的load_iris来加载数据集 # 加载鸢尾花数据集 iris = load_iris() # 创建DataFrame iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names) iris_df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names) # 使用 lmplot() 函数绘制图表 sns.lmplot( x="sepal length (cm)", # x 轴变量 y="petal length (cm)", # y 轴变量 hue="species", # 数据分类变量 data=iris_df, # 数据源 palette="Set1", # 为不同的 species 设置不同的颜色 markers=["o", "s", "D"], # 为不同的 species 设置不同的标记 height=5, # 图表高度 aspect=1.5, # 图表宽高比 fit_reg=False, # 关闭线性回归拟合线,因为我们更关注数据分布 scatter_kws={"s": 50, "alpha": 0.8} # 设置散点的大小和透明度 ) # 添加标题 plt.title("cjavapy") plt.draw() # 显示图形 plt.show()
3、数据集分布可视化
使用 Seaborn 的 lmplot()
创建的鸢尾花数据集的散点图。展示数据的分布而非探究变量间的线性关系。可视化是理解数据、模型和预测结果的重要工具。散点图用于展示数据点的分布情况,适用于数值型数据。
使用示例: