博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

04机器学习实战之朴素贝叶斯scikit-learn实现

Posted on 2019-03-08 17:00  心默默言  阅读(200)  评论(0)  编辑  收藏  举报
In [8]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.preprocessing import StandardScaler
from sklearn.naive_bayes import GaussianNB, MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsClassifier


def iris_type(s):
    """Map an iris species name to its integer class code.

    Used as a ``converters`` callback for ``np.loadtxt`` so that the textual
    species column can be stored in a float array.

    Args:
        s: Species name as ``str`` or ``bytes`` (``np.loadtxt`` may hand
            converters raw bytes depending on its ``encoding`` setting).
            Leading/trailing whitespace is ignored.

    Returns:
        0 for 'Iris-setosa', 1 for 'Iris-versicolor', 2 for 'Iris-virginica'.

    Raises:
        KeyError: If the name is not one of the three known species.
    """
    # Accept the bytes form so the converter works with any loadtxt encoding.
    if isinstance(s, bytes):
        s = s.decode('utf-8')
    it = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}
    # strip() tolerates trailing spaces or '\r' left over from the data file.
    return it[s.strip()]
In [15]:
# Load the comma-separated iris file into a float array. Column 4 holds the
# species name, which the iris_type converter maps to a numeric class code.
data = np.loadtxt('D:\\mlInAction\\8.iris.data', encoding='utf-8', dtype=float, delimiter=',',
                  converters={4: iris_type})
data
Out[15]:
array([[5.1, 3.5, 1.4, 0.2, 0. ],
       [4.9, 3. , 1.4, 0.2, 0. ],
       [4.7, 3.2, 1.3, 0.2, 0. ],
       [4.6, 3.1, 1.5, 0.2, 0. ],
       [5. , 3.6, 1.4, 0.2, 0. ],
       [5.4, 3.9, 1.7, 0.4, 0. ],
       [4.6, 3.4, 1.4, 0.3, 0. ],
       [5. , 3.4, 1.5, 0.2, 0. ],
       [4.4, 2.9, 1.4, 0.2, 0. ],
       [4.9, 3.1, 1.5, 0.1, 0. ],
       [5.4, 3.7, 1.5, 0.2, 0. ],
       [4.8, 3.4, 1.6, 0.2, 0. ],
       [4.8, 3. , 1.4, 0.1, 0. ],
       [4.3, 3. , 1.1, 0.1, 0. ],
       [5.8, 4. , 1.2, 0.2, 0. ],
       [5.7, 4.4, 1.5, 0.4, 0. ],
       [5.4, 3.9, 1.3, 0.4, 0. ],
       [5.1, 3.5, 1.4, 0.3, 0. ],
       [5.7, 3.8, 1.7, 0.3, 0. ],
       [5.1, 3.8, 1.5, 0.3, 0. ],
       [5.4, 3.4, 1.7, 0.2, 0. ],
       [5.1, 3.7, 1.5, 0.4, 0. ],
       [4.6, 3.6, 1. , 0.2, 0. ],
       [5.1, 3.3, 1.7, 0.5, 0. ],
       [4.8, 3.4, 1.9, 0.2, 0. ],
       [5. , 3. , 1.6, 0.2, 0. ],
       [5. , 3.4, 1.6, 0.4, 0. ],
       [5.2, 3.5, 1.5, 0.2, 0. ],
       [5.2, 3.4, 1.4, 0.2, 0. ],
       [4.7, 3.2, 1.6, 0.2, 0. ],
       [4.8, 3.1, 1.6, 0.2, 0. ],
       [5.4, 3.4, 1.5, 0.4, 0. ],
       [5.2, 4.1, 1.5, 0.1, 0. ],
       [5.5, 4.2, 1.4, 0.2, 0. ],
       [4.9, 3.1, 1.5, 0.1, 0. ],
       [5. , 3.2, 1.2, 0.2, 0. ],
       [5.5, 3.5, 1.3, 0.2, 0. ],
       [4.9, 3.1, 1.5, 0.1, 0. ],
       [4.4, 3. , 1.3, 0.2, 0. ],
       [5.1, 3.4, 1.5, 0.2, 0. ],
       [5. , 3.5, 1.3, 0.3, 0. ],
       [4.5, 2.3, 1.3, 0.3, 0. ],
       [4.4, 3.2, 1.3, 0.2, 0. ],
       [5. , 3.5, 1.6, 0.6, 0. ],
       [5.1, 3.8, 1.9, 0.4, 0. ],
       [4.8, 3. , 1.4, 0.3, 0. ],
       [5.1, 3.8, 1.6, 0.2, 0. ],
       [4.6, 3.2, 1.4, 0.2, 0. ],
       [5.3, 3.7, 1.5, 0.2, 0. ],
       [5. , 3.3, 1.4, 0.2, 0. ],
       [7. , 3.2, 4.7, 1.4, 1. ],
       [6.4, 3.2, 4.5, 1.5, 1. ],
       [6.9, 3.1, 4.9, 1.5, 1. ],
       [5.5, 2.3, 4. , 1.3, 1. ],
       [6.5, 2.8, 4.6, 1.5, 1. ],
       [5.7, 2.8, 4.5, 1.3, 1. ],
       [6.3, 3.3, 4.7, 1.6, 1. ],
       [4.9, 2.4, 3.3, 1. , 1. ],
       [6.6, 2.9, 4.6, 1.3, 1. ],
       [5.2, 2.7, 3.9, 1.4, 1. ],
       [5. , 2. , 3.5, 1. , 1. ],
       [5.9, 3. , 4.2, 1.5, 1. ],
       [6. , 2.2, 4. , 1. , 1. ],
       [6.1, 2.9, 4.7, 1.4, 1. ],
       [5.6, 2.9, 3.6, 1.3, 1. ],
       [6.7, 3.1, 4.4, 1.4, 1. ],
       [5.6, 3. , 4.5, 1.5, 1. ],
       [5.8, 2.7, 4.1, 1. , 1. ],
       [6.2, 2.2, 4.5, 1.5, 1. ],
       [5.6, 2.5, 3.9, 1.1, 1. ],
       [5.9, 3.2, 4.8, 1.8, 1. ],
       [6.1, 2.8, 4. , 1.3, 1. ],
       [6.3, 2.5, 4.9, 1.5, 1. ],
       [6.1, 2.8, 4.7, 1.2, 1. ],
       [6.4, 2.9, 4.3, 1.3, 1. ],
       [6.6, 3. , 4.4, 1.4, 1. ],
       [6.8, 2.8, 4.8, 1.4, 1. ],
       [6.7, 3. , 5. , 1.7, 1. ],
       [6. , 2.9, 4.5, 1.5, 1. ],
       [5.7, 2.6, 3.5, 1. , 1. ],
       [5.5, 2.4, 3.8, 1.1, 1. ],
       [5.5, 2.4, 3.7, 1. , 1. ],
       [5.8, 2.7, 3.9, 1.2, 1. ],
       [6. , 2.7, 5.1, 1.6, 1. ],
       [5.4, 3. , 4.5, 1.5, 1. ],
       [6. , 3.4, 4.5, 1.6, 1. ],
       [6.7, 3.1, 4.7, 1.5, 1. ],
       [6.3, 2.3, 4.4, 1.3, 1. ],
       [5.6, 3. , 4.1, 1.3, 1. ],
       [5.5, 2.5, 4. , 1.3, 1. ],
       [5.5, 2.6, 4.4, 1.2, 1. ],
       [6.1, 3. , 4.6, 1.4, 1. ],
       [5.8, 2.6, 4. , 1.2, 1. ],
       [5. , 2.3, 3.3, 1. , 1. ],
       [5.6, 2.7, 4.2, 1.3, 1. ],
       [5.7, 3. , 4.2, 1.2, 1. ],
       [5.7, 2.9, 4.2, 1.3, 1. ],
       [6.2, 2.9, 4.3, 1.3, 1. ],
       [5.1, 2.5, 3. , 1.1, 1. ],
       [5.7, 2.8, 4.1, 1.3, 1. ],
       [6.3, 3.3, 6. , 2.5, 2. ],
       [5.8, 2.7, 5.1, 1.9, 2. ],
       [7.1, 3. , 5.9, 2.1, 2. ],
       [6.3, 2.9, 5.6, 1.8, 2. ],
       [6.5, 3. , 5.8, 2.2, 2. ],
       [7.6, 3. , 6.6, 2.1, 2. ],
       [4.9, 2.5, 4.5, 1.7, 2. ],
       [7.3, 2.9, 6.3, 1.8, 2. ],
       [6.7, 2.5, 5.8, 1.8, 2. ],
       [7.2, 3.6, 6.1, 2.5, 2. ],
       [6.5, 3.2, 5.1, 2. , 2. ],
       [6.4, 2.7, 5.3, 1.9, 2. ],
       [6.8, 3. , 5.5, 2.1, 2. ],
       [5.7, 2.5, 5. , 2. , 2. ],
       [5.8, 2.8, 5.1, 2.4, 2. ],
       [6.4, 3.2, 5.3, 2.3, 2. ],
       [6.5, 3. , 5.5, 1.8, 2. ],
       [7.7, 3.8, 6.7, 2.2, 2. ],
       [7.7, 2.6, 6.9, 2.3, 2. ],
       [6. , 2.2, 5. , 1.5, 2. ],
       [6.9, 3.2, 5.7, 2.3, 2. ],
       [5.6, 2.8, 4.9, 2. , 2. ],
       [7.7, 2.8, 6.7, 2. , 2. ],
       [6.3, 2.7, 4.9, 1.8, 2. ],
       [6.7, 3.3, 5.7, 2.1, 2. ],
       [7.2, 3.2, 6. , 1.8, 2. ],
       [6.2, 2.8, 4.8, 1.8, 2. ],
       [6.1, 3. , 4.9, 1.8, 2. ],
       [6.4, 2.8, 5.6, 2.1, 2. ],
       [7.2, 3. , 5.8, 1.6, 2. ],
       [7.4, 2.8, 6.1, 1.9, 2. ],
       [7.9, 3.8, 6.4, 2. , 2. ],
       [6.4, 2.8, 5.6, 2.2, 2. ],
       [6.3, 2.8, 5.1, 1.5, 2. ],
       [6.1, 2.6, 5.6, 1.4, 2. ],
       [7.7, 3. , 6.1, 2.3, 2. ],
       [6.3, 3.4, 5.6, 2.4, 2. ],
       [6.4, 3.1, 5.5, 1.8, 2. ],
       [6. , 3. , 4.8, 1.8, 2. ],
       [6.9, 3.1, 5.4, 2.1, 2. ],
       [6.7, 3.1, 5.6, 2.4, 2. ],
       [6.9, 3.1, 5.1, 2.3, 2. ],
       [5.8, 2.7, 5.1, 1.9, 2. ],
       [6.8, 3.2, 5.9, 2.3, 2. ],
       [6.7, 3.3, 5.7, 2.5, 2. ],
       [6.7, 3. , 5.2, 2.3, 2. ],
       [6.3, 2.5, 5. , 1.9, 2. ],
       [6.5, 3. , 5.2, 2. , 2. ],
       [6.2, 3.4, 5.4, 2.3, 2. ],
       [5.9, 3. , 5.1, 1.8, 2. ]])
In [16]:
x, y = np.split(data, (4,), axis=1)  # the first four columns become x, the remaining column becomes y
x
Out[16]:
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [4.5, 2.3, 1.3, 0.3],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.1, 3.8, 1.9, 0.4],
       [4.8, 3. , 1.4, 0.3],
       [5.1, 3.8, 1.6, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.9, 3.1, 4.9, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [4.9, 2.4, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [5. , 2. , 3.5, 1. ],
       [5.9, 3. , 4.2, 1.5],
       [6. , 2.2, 4. , 1. ],
       [6.1, 2.9, 4.7, 1.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.7, 3.1, 4.4, 1.4],
       [5.6, 3. , 4.5, 1.5],
       [5.8, 2.7, 4.1, 1. ],
       [6.2, 2.2, 4.5, 1.5],
       [5.6, 2.5, 3.9, 1.1],
       [5.9, 3.2, 4.8, 1.8],
       [6.1, 2.8, 4. , 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.4, 2.9, 4.3, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [6.7, 3. , 5. , 1.7],
       [6. , 2.9, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [5.5, 2.4, 3.7, 1. ],
       [5.8, 2.7, 3.9, 1.2],
       [6. , 2.7, 5.1, 1.6],
       [5.4, 3. , 4.5, 1.5],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [6.3, 2.3, 4.4, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.6, 4. , 1.2],
       [5. , 2.3, 3.3, 1. ],
       [5.6, 2.7, 4.2, 1.3],
       [5.7, 3. , 4.2, 1.2],
       [5.7, 2.9, 4.2, 1.3],
       [6.2, 2.9, 4.3, 1.3],
       [5.1, 2.5, 3. , 1.1],
       [5.7, 2.8, 4.1, 1.3],
       [6.3, 3.3, 6. , 2.5],
       [5.8, 2.7, 5.1, 1.9],
       [7.1, 3. , 5.9, 2.1],
       [6.3, 2.9, 5.6, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [4.9, 2.5, 4.5, 1.7],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 2.5, 5.8, 1.8],
       [7.2, 3.6, 6.1, 2.5],
       [6.5, 3.2, 5.1, 2. ],
       [6.4, 2.7, 5.3, 1.9],
       [6.8, 3. , 5.5, 2.1],
       [5.7, 2.5, 5. , 2. ],
       [5.8, 2.8, 5.1, 2.4],
       [6.4, 3.2, 5.3, 2.3],
       [6.5, 3. , 5.5, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.2, 5. , 1.5],
       [6.9, 3.2, 5.7, 2.3],
       [5.6, 2.8, 4.9, 2. ],
       [7.7, 2.8, 6.7, 2. ],
       [6.3, 2.7, 4.9, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3.2, 6. , 1.8],
       [6.2, 2.8, 4.8, 1.8],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [7.4, 2.8, 6.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [6.3, 2.8, 5.1, 1.5],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.4, 5.6, 2.4],
       [6.4, 3.1, 5.5, 1.8],
       [6. , 3. , 4.8, 1.8],
       [6.9, 3.1, 5.4, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [5.8, 2.7, 5.1, 1.9],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]])
In [17]:
y  # display the label column produced by the split
Out[17]:
array([[0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.],
       [2.]])
In [18]:
x = x[:, :2]  # keep only the first two columns as the features
x
Out[18]:
array([[5.1, 3.5],
       [4.9, 3. ],
       [4.7, 3.2],
       [4.6, 3.1],
       [5. , 3.6],
       [5.4, 3.9],
       [4.6, 3.4],
       [5. , 3.4],
       [4.4, 2.9],
       [4.9, 3.1],
       [5.4, 3.7],
       [4.8, 3.4],
       [4.8, 3. ],
       [4.3, 3. ],
       [5.8, 4. ],
       [5.7, 4.4],
       [5.4, 3.9],
       [5.1, 3.5],
       [5.7, 3.8],
       [5.1, 3.8],
       [5.4, 3.4],
       [5.1, 3.7],
       [4.6, 3.6],
       [5.1, 3.3],
       [4.8, 3.4],
       [5. , 3. ],
       [5. , 3.4],
       [5.2, 3.5],
       [5.2, 3.4],
       [4.7, 3.2],
       [4.8, 3.1],
       [5.4, 3.4],
       [5.2, 4.1],
       [5.5, 4.2],
       [4.9, 3.1],
       [5. , 3.2],
       [5.5, 3.5],
       [4.9, 3.1],
       [4.4, 3. ],
       [5.1, 3.4],
       [5. , 3.5],
       [4.5, 2.3],
       [4.4, 3.2],
       [5. , 3.5],
       [5.1, 3.8],
       [4.8, 3. ],
       [5.1, 3.8],
       [4.6, 3.2],
       [5.3, 3.7],
       [5. , 3.3],
       [7. , 3.2],
       [6.4, 3.2],
       [6.9, 3.1],
       [5.5, 2.3],
       [6.5, 2.8],
       [5.7, 2.8],
       [6.3, 3.3],
       [4.9, 2.4],
       [6.6, 2.9],
       [5.2, 2.7],
       [5. , 2. ],
       [5.9, 3. ],
       [6. , 2.2],
       [6.1, 2.9],
       [5.6, 2.9],
       [6.7, 3.1],
       [5.6, 3. ],
       [5.8, 2.7],
       [6.2, 2.2],
       [5.6, 2.5],
       [5.9, 3.2],
       [6.1, 2.8],
       [6.3, 2.5],
       [6.1, 2.8],
       [6.4, 2.9],
       [6.6, 3. ],
       [6.8, 2.8],
       [6.7, 3. ],
       [6. , 2.9],
       [5.7, 2.6],
       [5.5, 2.4],
       [5.5, 2.4],
       [5.8, 2.7],
       [6. , 2.7],
       [5.4, 3. ],
       [6. , 3.4],
       [6.7, 3.1],
       [6.3, 2.3],
       [5.6, 3. ],
       [5.5, 2.5],
       [5.5, 2.6],
       [6.1, 3. ],
       [5.8, 2.6],
       [5. , 2.3],
       [5.6, 2.7],
       [5.7, 3. ],
       [5.7, 2.9],
       [6.2, 2.9],
       [5.1, 2.5],
       [5.7, 2.8],
       [6.3, 3.3],
       [5.8, 2.7],
       [7.1, 3. ],
       [6.3, 2.9],
       [6.5, 3. ],
       [7.6, 3. ],
       [4.9, 2.5],
       [7.3, 2.9],
       [6.7, 2.5],
       [7.2, 3.6],
       [6.5, 3.2],
       [6.4, 2.7],
       [6.8, 3. ],
       [5.7, 2.5],
       [5.8, 2.8],
       [6.4, 3.2],
       [6.5, 3. ],
       [7.7, 3.8],
       [7.7, 2.6],
       [6. , 2.2],
       [6.9, 3.2],
       [5.6, 2.8],
       [7.7, 2.8],
       [6.3, 2.7],
       [6.7, 3.3],
       [7.2, 3.2],
       [6.2, 2.8],
       [6.1, 3. ],
       [6.4, 2.8],
       [7.2, 3. ],
       [7.4, 2.8],
       [7.9, 3.8],
       [6.4, 2.8],
       [6.3, 2.8],
       [6.1, 2.6],
       [7.7, 3. ],
       [6.3, 3.4],
       [6.4, 3.1],
       [6. , 3. ],
       [6.9, 3.1],
       [6.7, 3.1],
       [6.9, 3.1],
       [5.8, 2.7],
       [6.8, 3.2],
       [6.7, 3.3],
       [6.7, 3. ],
       [6.3, 2.5],
       [6.5, 3. ],
       [6.2, 3.4],
       [5.9, 3. ]])
In [19]:
gnb = Pipeline([
    ('sc', StandardScaler()),  # standardize each feature to zero mean and unit variance
    ('clf', GaussianNB())])  # Gaussian naive Bayes: models each feature as normally distributed within a class
In [20]:
y.ravel()  # flatten the (n, 1) label column into a 1-D array
Out[20]:
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])
In [21]:
# Fit the scaler and the classifier on the selected features; the labels are
# passed as a 1-D array.
gnb.fit(x, y.ravel())
Out[21]:
Pipeline(memory=None,
     steps=[('sc', StandardScaler(copy=True, with_mean=True, with_std=True)), ('clf', GaussianNB(priors=None, var_smoothing=1e-09))])
In [23]:
# Predict on the same samples the pipeline was fitted on.
y_hat = gnb.predict(x)
y_hat
Out[23]:
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 2.,
       2., 2., 1., 2., 1., 2., 1., 2., 1., 1., 1., 1., 1., 1., 2., 1., 1.,
       1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 1.,
       2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1.,
       2., 1., 2., 2., 1., 2., 2., 2., 2., 1., 2., 1., 1., 2., 2., 2., 2.,
       1., 2., 1., 2., 1., 2., 2., 1., 1., 1., 2., 2., 2., 1., 1., 1., 2.,
       2., 2., 1., 2., 2., 2., 1., 2., 2., 2., 1., 2., 2., 1.])
In [24]:
y = y.reshape(-1)  # flatten to 1-D; same result as y.ravel()
y
Out[24]:
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])
In [25]:
# Element-wise comparison: True where the prediction matches the label.
result = y_hat == y
result
Out[25]:
array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True, False,  True,  True,  True,
        True,  True,  True,  True,  True, False, False, False,  True,
       False,  True, False,  True, False,  True,  True,  True,  True,
        True,  True, False,  True,  True,  True,  True,  True,  True,
        True,  True, False, False, False, False,  True,  True,  True,
        True,  True,  True,  True, False, False,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True, False,  True, False,  True,  True, False,  True,
        True,  True,  True, False,  True, False, False,  True,  True,
        True,  True, False,  True, False,  True, False,  True,  True,
       False, False, False,  True,  True,  True, False, False, False,
        True,  True,  True, False,  True,  True,  True, False,  True,
        True,  True, False,  True,  True, False])
In [27]:
# The mean of a boolean array counts True as 1 and False as 0, which gives the
# accuracy. Note it is measured on the same data the model was fitted on.
acc = np.mean(result)
acc
Out[27]:
0.78
 

以下为版本2

In [29]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.preprocessing import StandardScaler, MinMaxScaler, PolynomialFeatures
from sklearn.naive_bayes import GaussianNB, MultinomialNB  #高斯贝叶斯和多项式朴素贝叶斯
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Configure matplotlib so that Chinese text and the minus sign render correctly
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False

# Feature names (sepal length, sepal width, petal length, petal width) in
# English and Chinese, followed by the three class names
iris_feature_E = 'sepal length', 'sepal width', 'petal length', 'petal width'
iris_feature_C = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度'
iris_class = 'Iris-setosa', 'Iris-versicolor', 'Iris-virginica'
# Column indices used as model inputs (petal length and petal width)
features = [2, 3]

# Read the data; header=None because the file has no header row
path = 'D:\\mlInAction\\8.iris.data'  # path to the data file
data = pd.read_csv(path, header=None)
data
Out[29]:
 
 0 1 2 3 4
0 5.1 3.5 1.4 0.2 Iris-setosa
1 4.9 3.0 1.4 0.2 Iris-setosa
2 4.7 3.2 1.3 0.2 Iris-setosa
3 4.6 3.1 1.5 0.2 Iris-setosa
4 5.0 3.6 1.4 0.2 Iris-setosa
5 5.4 3.9 1.7 0.4 Iris-setosa
6 4.6 3.4 1.4 0.3 Iris-setosa
7 5.0 3.4 1.5 0.2 Iris-setosa
8 4.4 2.9 1.4 0.2 Iris-setosa
9 4.9 3.1 1.5 0.1 Iris-setosa
10 5.4 3.7 1.5 0.2 Iris-setosa
11 4.8 3.4 1.6 0.2 Iris-setosa
12 4.8 3.0 1.4 0.1 Iris-setosa
13 4.3 3.0 1.1 0.1 Iris-setosa
14 5.8 4.0 1.2 0.2 Iris-setosa
15 5.7 4.4 1.5 0.4 Iris-setosa
16 5.4 3.9 1.3 0.4 Iris-setosa
17 5.1 3.5 1.4 0.3 Iris-setosa
18 5.7 3.8 1.7 0.3 Iris-setosa
19 5.1 3.8 1.5 0.3 Iris-setosa
20 5.4 3.4 1.7 0.2 Iris-setosa
21 5.1 3.7 1.5 0.4 Iris-setosa
22 4.6 3.6 1.0 0.2 Iris-setosa
23 5.1 3.3 1.7 0.5 Iris-setosa
24 4.8 3.4 1.9 0.2 Iris-setosa
25 5.0 3.0 1.6 0.2 Iris-setosa
26 5.0 3.4 1.6 0.4 Iris-setosa
27 5.2 3.5 1.5 0.2 Iris-setosa
28 5.2 3.4 1.4 0.2 Iris-setosa
29 4.7 3.2 1.6 0.2 Iris-setosa
... ... ... ... ... ...
120 6.9 3.2 5.7 2.3 Iris-virginica
121 5.6 2.8 4.9 2.0 Iris-virginica
122 7.7 2.8 6.7 2.0 Iris-virginica
123 6.3 2.7 4.9 1.8 Iris-virginica
124 6.7 3.3 5.7 2.1 Iris-virginica
125 7.2 3.2 6.0 1.8 Iris-virginica
126 6.2 2.8 4.8 1.8 Iris-virginica
127 6.1 3.0 4.9 1.8 Iris-virginica
128 6.4 2.8 5.6 2.1 Iris-virginica
129 7.2 3.0 5.8 1.6 Iris-virginica
130 7.4 2.8 6.1 1.9 Iris-virginica
131 7.9 3.8 6.4 2.0 Iris-virginica
132 6.4 2.8 5.6 2.2 Iris-virginica
133 6.3 2.8 5.1 1.5 Iris-virginica
134 6.1 2.6 5.6 1.4 Iris-virginica
135 7.7 3.0 6.1 2.3 Iris-virginica
136 6.3 3.4 5.6 2.4 Iris-virginica
137 6.4 3.1 5.5 1.8 Iris-virginica
138 6.0 3.0 4.8 1.8 Iris-virginica
139 6.9 3.1 5.4 2.1 Iris-virginica
140 6.7 3.1 5.6 2.4 Iris-virginica
141 6.9 3.1 5.1 2.3 Iris-virginica
142 5.8 2.7 5.1 1.9 Iris-virginica
143 6.8 3.2 5.9 2.3 Iris-virginica
144 6.7 3.3 5.7 2.5 Iris-virginica
145 6.7 3.0 5.2 2.3 Iris-virginica
146 6.3 2.5 5.0 1.9 Iris-virginica
147 6.5 3.0 5.2 2.0 Iris-virginica
148 6.2 3.4 5.4 2.3 Iris-virginica
149 5.9 3.0 5.1 1.8 Iris-virginica

150 rows × 5 columns

In [35]:
x = data[list(range(4))]  # data is a DataFrame: select columns 0-3 by label list, not a NumPy-style slice
x
Out[35]:
 
 0 1 2 3
0 5.1 3.5 1.4 0.2
1 4.9 3.0 1.4 0.2
2 4.7 3.2 1.3 0.2
3 4.6 3.1 1.5 0.2
4 5.0 3.6 1.4 0.2
5 5.4 3.9 1.7 0.4
6 4.6 3.4 1.4 0.3
7 5.0 3.4 1.5 0.2
8 4.4 2.9 1.4 0.2
9 4.9 3.1 1.5 0.1
10 5.4 3.7 1.5 0.2
11 4.8 3.4 1.6 0.2
12 4.8 3.0 1.4 0.1
13 4.3 3.0 1.1 0.1
14 5.8 4.0 1.2 0.2
15 5.7 4.4 1.5 0.4
16 5.4 3.9 1.3 0.4
17 5.1 3.5 1.4 0.3
18 5.7 3.8 1.7 0.3
19 5.1 3.8 1.5 0.3
20 5.4 3.4 1.7 0.2
21 5.1 3.7 1.5 0.4
22 4.6 3.6 1.0 0.2
23 5.1 3.3 1.7 0.5
24 4.8 3.4 1.9 0.2
25 5.0 3.0 1.6 0.2
26 5.0 3.4 1.6 0.4
27 5.2 3.5 1.5 0.2
28 5.2 3.4 1.4 0.2
29 4.7 3.2 1.6 0.2
... ... ... ... ...
120 6.9 3.2 5.7 2.3
121 5.6 2.8 4.9 2.0
122 7.7 2.8 6.7 2.0
123 6.3 2.7 4.9 1.8
124 6.7 3.3 5.7 2.1
125 7.2 3.2 6.0 1.8
126 6.2 2.8 4.8 1.8
127 6.1 3.0 4.9 1.8
128 6.4 2.8 5.6 2.1
129 7.2 3.0 5.8 1.6
130 7.4 2.8 6.1 1.9
131 7.9 3.8 6.4 2.0
132 6.4 2.8 5.6 2.2
133 6.3 2.8 5.1 1.5
134 6.1 2.6 5.6 1.4
135 7.7 3.0 6.1 2.3
136 6.3 3.4 5.6 2.4
137 6.4 3.1 5.5 1.8
138 6.0 3.0 4.8 1.8
139 6.9 3.1 5.4 2.1
140 6.7 3.1 5.6 2.4
141 6.9 3.1 5.1 2.3
142 5.8 2.7 5.1 1.9
143 6.8 3.2 5.9 2.3
144 6.7 3.3 5.7 2.5
145 6.7 3.0 5.2 2.3
146 6.3 2.5 5.0 1.9
147 6.5 3.0 5.2 2.0
148 6.2 3.4 5.4 2.3
149 5.9 3.0 5.1 1.8

150 rows × 4 columns

In [36]:
x = x[features]  # keep only the columns listed in features
x
Out[36]:
 
 2 3
0 1.4 0.2
1 1.4 0.2
2 1.3 0.2
3 1.5 0.2
4 1.4 0.2
5 1.7 0.4
6 1.4 0.3
7 1.5 0.2
8 1.4 0.2
9 1.5 0.1
10 1.5 0.2
11 1.6 0.2
12 1.4 0.1
13 1.1 0.1
14 1.2 0.2
15 1.5 0.4
16 1.3 0.4
17 1.4 0.3
18 1.7 0.3
19 1.5 0.3
20 1.7 0.2
21 1.5 0.4
22 1.0 0.2
23 1.7 0.5
24 1.9 0.2
25 1.6 0.2
26 1.6 0.4
27 1.5 0.2
28 1.4 0.2
29 1.6 0.2
... ... ...
120 5.7 2.3
121 4.9 2.0
122 6.7 2.0
123 4.9 1.8
124 5.7 2.1
125 6.0 1.8
126 4.8 1.8
127 4.9 1.8
128 5.6 2.1
129 5.8 1.6
130 6.1 1.9
131 6.4 2.0
132 5.6 2.2
133 5.1 1.5
134 5.6 1.4
135 6.1 2.3
136 5.6 2.4
137 5.5 1.8
138 4.8 1.8
139 5.4 2.1
140 5.6 2.4
141 5.1 2.3
142 5.1 1.9
143 5.9 2.3
144 5.7 2.5
145 5.2 2.3
146 5.0 1.9
147 5.2 2.0
148 5.4 2.3
149 5.1 1.8

150 rows × 2 columns

In [37]:
y = pd.Categorical(data[4]).codes  # encode the class names in column 4 as integer codes 0, 1, 2
y
Out[37]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int8)
In [38]:
# x.shape is a (rows, columns) tuple, so it fills both %d placeholders.
print("总样本数目:%d;特征属性数目:%d" % x.shape)
 
总样本数目:150;特征属性数目:2
In [40]:
# Hold out 20% of the samples for testing; the fixed random_state makes the
# split reproducible.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=14)
print("训练数据集样本数目:%d, 测试数据集样本数目:%d" % (x_train.shape[0], x_test.shape[0]))
 
训练数据集样本数目:120, 测试数据集样本数目:30
In [41]:
# NOTE(review): with degree=1 the 'poly' step adds no higher-order terms; it
# only passes the inputs through plus a constant bias column.
clf = Pipeline([
    ('sc', StandardScaler()),  # standardize to zero mean and unit variance (rescales only; does not make the data Gaussian)
    ('poly', PolynomialFeatures(degree=1)),
    ('clf', GaussianNB())])  # (MultinomialNB would require non-negative feature values)
# Train the model
clf.fit(x_train, y_train)
Out[41]:
Pipeline(memory=None,
     steps=[('sc', StandardScaler(copy=True, with_mean=True, with_std=True)), ('poly', PolynomialFeatures(degree=1, include_bias=True, interaction_only=False)), ('clf', GaussianNB(priors=None, var_smoothing=1e-09))])
In [42]:
# Report accuracy (as a percentage) on the training split and on the held-out
# test split.
y_train_hat = clf.predict(x_train)
print('训练集准确度: %.2f%%' % (100 * accuracy_score(y_train, y_train_hat)))
y_test_hat = clf.predict(x_test)
print('测试集准确度:%.2f%%' % (100 * accuracy_score(y_test, y_test_hat)))
 
训练集准确度: 95.83%
测试集准确度:96.67%