使用 ID3 对 Titanic 进行决策树分类

原创转载请注明出处:https://www.cnblogs.com/agilestyle/p/12722688.html

 

过程划分

 

数据加载

import graphviz
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.feature_extraction import DictVectorizer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# 数据加载
train_data = pd.read_csv(r'/data/Titanic/train.csv')
test_data = pd.read_csv(r'/data/Titanic/train.csv')

 

数据探索

# 数据探索
print('-' * 30)
print(train_data.info())
print('-' * 30)
print(train_data.describe())
print('-' * 30)
print(train_data.describe(include=['O']))
print('-' * 30)
print(train_data.head())
print('-' * 30)
print(train_data.tail())

Console Output

------------------------------
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
PassengerId    891 non-null int64
Survived       891 non-null int64
Pclass         891 non-null int64
Name           891 non-null object
Sex            891 non-null object
Age            714 non-null float64
SibSp          891 non-null int64
Parch          891 non-null int64
Ticket         891 non-null object
Fare           891 non-null float64
Cabin          204 non-null object
Embarked       889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
None
------------------------------
       PassengerId    Survived      Pclass         Age       SibSp  \
count   891.000000  891.000000  891.000000  714.000000  891.000000   
mean    446.000000    0.383838    2.308642   29.699118    0.523008   
std     257.353842    0.486592    0.836071   14.526497    1.102743   
min       1.000000    0.000000    1.000000    0.420000    0.000000   
25%     223.500000    0.000000    2.000000   20.125000    0.000000   
50%     446.000000    0.000000    3.000000   28.000000    0.000000   
75%     668.500000    1.000000    3.000000   38.000000    1.000000   
max     891.000000    1.000000    3.000000   80.000000    8.000000   

            Parch        Fare  
count  891.000000  891.000000  
mean     0.381594   32.204208  
std      0.806057   49.693429  
min      0.000000    0.000000  
25%      0.000000    7.910400  
50%      0.000000   14.454200  
75%      0.000000   31.000000  
max      6.000000  512.329200  
------------------------------
                                                   Name   Sex  Ticket Cabin  \
count                                               891   891     891   204   
unique                                              891     2     681   147   
top     Lobb, Mrs. William Arthur (Cordelia K Stanlick)  male  347082    G6   
freq                                                  1   577       7     4   

       Embarked  
count       889  
unique        3  
top           S  
freq        644  
------------------------------
   PassengerId  Survived  Pclass  \
0            1         0       3   
1            2         1       1   
2            3         1       3   
3            4         1       1   
4            5         0       3   

                                                Name     Sex   Age  SibSp  \
0                            Braund, Mr. Owen Harris    male  22.0      1   
1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   
2                             Heikkinen, Miss. Laina  female  26.0      0   
3       Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1   
4                           Allen, Mr. William Henry    male  35.0      0   

   Parch            Ticket     Fare Cabin Embarked  
0      0         A/5 21171   7.2500   NaN        S  
1      0          PC 17599  71.2833   C85        C  
2      0  STON/O2. 3101282   7.9250   NaN        S  
3      0            113803  53.1000  C123        S  
4      0            373450   8.0500   NaN        S  
------------------------------
     PassengerId  Survived  Pclass                                      Name  \
886          887         0       2                     Montvila, Rev. Juozas   
887          888         1       1              Graham, Miss. Margaret Edith   
888          889         0       3  Johnston, Miss. Catherine Helen "Carrie"   
889          890         1       1                     Behr, Mr. Karl Howell   
890          891         0       3                       Dooley, Mr. Patrick   

        Sex   Age  SibSp  Parch      Ticket   Fare Cabin Embarked  
886    male  27.0      0      0      211536  13.00   NaN        S  
887  female  19.0      0      0      112053  30.00   B42        S  
888  female   NaN      1      2  W./C. 6607  23.45   NaN        S  
889    male  26.0      0      0      111369  30.00  C148        C  
890    male  32.0      0      0      370376   7.75   NaN        Q  

 

数据清洗

# 数据清洗
# 使用平均年龄来填充年龄中的 nan 值
train_data['Age'].fillna(train_data['Age'].mean(), inplace=True)
test_data['Age'].fillna(test_data['Age'].mean(), inplace=True)
# 使用票价的均值填充票价中的 nan 值
train_data['Fare'].fillna(train_data['Fare'].mean(), inplace=True)
test_data['Fare'].fillna(test_data['Fare'].mean(), inplace=True)
# 使用登录最多的港口来填充登录港口的 nan 值
print(train_data['Embarked'].value_counts())
train_data['Embarked'].fillna('S', inplace=True)
test_data['Embarked'].fillna('S', inplace=True)

 

特征选择

特征选择是分类器的关键。特征选择不同,得到的分类器也不同。可以通过数据探索发现来选择哪些特征做生存的预测。PassengerId 为乘客编号,对分类没有作用,可以放弃;Name 为乘客姓名,对分类没有作用,可以放弃;Cabin 字段缺失值太多,可以放弃;Ticket 字段为船票号码,杂乱无章且无规律,可以放弃。其余的字段包括:Pclass、Sex、Age、SibSp、Parch 和 Fare,这些属性分别表示了乘客的船票等级、性别、年龄、亲戚数量以及船票价格,可能会和乘客的生存预测分类有关系。具体是什么关系,可以交给分类器来处理。

# 特征选择
features = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']
train_features = train_data[features]
train_labels = train_data['Survived']
test_features = test_data[features]
dvec = DictVectorizer(sparse=False)
# fit_transform 函数将特征向量转化为特征值矩阵
train_features = dvec.fit_transform(train_features.to_dict(orient='record'))
print(dvec.feature_names_)

Console Output

['Age', 'Embarked=C', 'Embarked=Q', 'Embarked=S', 'Fare', 'Parch', 'Pclass', 'Sex=female', 'Sex=male', 'SibSp']
(891, 10)

可以看到原本是一列的 Embarked,变成了“Embarked=C”、“Embarked=Q”、“Embarked=S”三列。Sex 列变成了“Sex=female”、“Sex=male”两列。

这样 train_features 特征矩阵就包括 10 个特征值(列),以及 891 个样本(行),即 891 行,10 列的特征矩阵。

Note: fit_transform 和 transform 的区别

  • fit 从一个训练集中学习模型参数,其中就包括了归一化时用到的均值,标准偏差等,可以理解为一个训练过程。
  • transform: 在fit的基础上,对数据进行标准化,降维,归一化等数据转换操作。
  • fit_transform: 将模型训练和转化合并到一起,训练样本先做fit,得到mean,standard deviation,然后将这些参数用于transform(归一化训练数据),使得到的训练数据是归一化的,而测试数据只需要在原先fit得到的mean,std上来做归一化就行了,所以用transform就行了。

需要注意的是,transform和fit_transform虽然结果相同,但是不能互换。因为fit_transform只是 fit+transform两个步骤合并的简写。而各种分类算法都需要先fit,然后再进行transform。所以如果把fit_transform替换为transform可能会报错。

 

建模训练

# 决策树模型
# 构造 ID3 决策树
clf = DecisionTreeClassifier(criterion='entropy')

# 决策树训练
clf.fit(train_features, train_labels)

# 模型预测评估
test_features=dvec.transform(test_features.to_dict(orient='record'))

# 决策树预测
pred_labels = clf.predict(test_features)

# 决策树准确率
from sklearn.metrics import accuracy_score

train_score = clf.score(train_features, train_labels)
test_labels = test_data['Survived']
test_score = accuracy_score(test_labels, pred_labels)

print(u'train_score 准确率为 %.4lf' % train_score)
print(u'test_score 准确率为 %.4lf' % test_score)

Console Output

train_score 准确率为 0.9820
test_score 准确率为 0.9820

Note: 

用训练集做训练,再用训练集自身做准确率评估自然会很高。但这样得出的准确率并不能代表决策树分类器的准确率。因为没有测试集的实际结果,因此无法用测试集的预测结果与实际结果做对比。如果使用 score 函数对训练集的准确率进行统计,正确率会接近于 100%(如上结果为 98.2%),无法对分类器的在实际环境下做准确率的评估。模型准确率需要考虑是否有测试集的实际结果可以做对比,当测试集没有真实结果可以对比时,需要使用 K 折交叉验证 cross_val_score。

 

K 折交叉验证

交叉验证是一种常用的验证分类准确率的方法,原理是拿出大部分样本进行训练,少量的用于分类器的验证。K 折交叉验证,就是做 K 次交叉验证,每次选取 K 分之一的数据作为验证,其余作为训练。轮流 K 次,取平均值。

K 折交叉验证的原理

  1. 将数据集平均分割成 K 个等份;
  2. 使用 1 份数据作为测试数据,其余作为训练数据;
  3. 计算测试准确率;
  4. 使用不同的测试集,重复 2、3 步骤。

在 sklearn 的 model_selection 模型选择中提供了 cross_val_score 函数。cross_val_score 函数中的参数 cv 代表对原始数据划分成多少份,也就是 K 值,一般建议 K 值取 10,因此可以设置 CV=10,可以对比下 score 和 cross_val_score 两种函数的正确率的评估结果。

# K 折交叉验证统计决策树准确率
cv_score = np.mean(cross_val_score(clf, train_features, train_labels, cv=10))
print(u'cross_val_score 准确率为 %.4lf' % cv_score)

Console Output (每次运行结果可能会有不同)

cross_val_score 准确率为 0.7746

 

决策树可视化

dot_data = tree.export_graphviz(clf, out_file=None)
graph = graphviz.Source(dot_data)
graph.view('titanic')

Note:如果提示 graphviz 不可用或者引用不到等错误,运行如下命令进行安装

conda install graphviz
conda install python-graphviz
conda install pydot

执行后,可以得到下面的图示

 

Reference

https://github.com/cystanford/Titanic_Data

https://time.geekbang.org/column/article/79072

https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.describe.html

https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.fillna.html

https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_dict.html

https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.DictVectorizer.html

https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_score.html

 

posted @ 2020-04-17 21:47  李白与酒  阅读(756)  评论(0编辑  收藏  举报