0x02 机器学习-交叉验证(附带保存训练好的模型)

01机器学习可能出现的误差

  1. 过拟合:机器"读死书"、死记硬背,在训练数据(作业)上表现不错,但在没见过的测试数据(考试)上表现很糟糕

  2. 欠拟合:机器没有好好学习,没能学到数据中的规律,在训练数据和测试数据上的结果都不令人满意

02交叉验证-1

  用交叉验证的平均得分评估模型,并循环尝试不同的 n_neighbors 取值来寻找较优解

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Load the iris data set: X holds the feature matrix, y the class labels.
iris = load_iris()
X = iris.data
y = iris.target

# Level 1: a single train/test split.
# "TTTTt": T marks a training portion, t marks the test portion.
#
#   X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=4)
#   knn = KNeighborsClassifier(n_neighbors=5)  # decide each point from its 5 nearest neighbours
#   knn.fit(X_train, y_train)
#   print(knn.score(X_test, y_test))

# Level 2: repeat 5 times and take the average; the test portion rotates:
#   TTTTt  TTTtT  TTtTT  TtTTT  tTTTT
#
#   from sklearn.model_selection import cross_val_score
#   knn = KNeighborsClassifier(n_neighbors=5)
#   scores = cross_val_score(knn, X, y, cv=5, scoring='accuracy')
#   print(scores.mean())

# Level 3: loop over candidate values of n_neighbors and compare the mean
# cross-validated score obtained for each one.
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt

k_range = range(1, 31)
k_scores = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    # Alternative: plot an error curve instead of an accuracy curve.
    # Note the leading minus sign, which turns the negated score into a loss.
    # loss = -cross_val_score(knn, X, y, cv=10, scoring='neg_mean_squared_error')
    scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')
    k_scores.append(scores.mean())

# Visualise the mean score for each value of k.
plt.plot(k_range, k_scores)
plt.xlabel('Value of K for Knn')
plt.ylabel('Cross-Validated Accuracy')
plt.show()

02交叉验证-2

   对结果一步一步的可视化

  使用 learning_curve 在不同的训练集大小下分段评估模型结果的好坏

from sklearn.model_selection import learning_curve  # 这个是调参用的 评价曲线
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np

# Load the digits data set: X holds the feature matrix, y the class labels.
digis = load_digits()
X = digis.data
y = digis.target

# Score the model at several training-set sizes.
# Arguments: estimator, data, targets, number of cross-validation folds,
# scoring metric, and the training-set fractions at which to record scores.
train_sizes, train_loss, test_loss = learning_curve(
    SVC(gamma=0.001), X, y, cv=10, scoring='neg_mean_squared_error',
    train_sizes=[0.1, 0.25, 0.50, 0.75, 1])  # 10%, 25%, 50%, 75%, 100%: five points along the way

# Layout of train_loss (test_loss has the same layout): one row per training
# size, one column per cross-validation fold.
#
#   size   cv1  cv2  cv3  cv4  cv5  cv6  cv7  cv8  cv9  cv10
#   10%
#   25%
#   50%
#   75%
#   100%

# Average across the folds (axis=1) to get one value per training size.
# The scorer is the *negative* mean squared error, so negate it to get a loss.
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

# Visualisation:
#   red line   - result on the training data ("homework")
#   green line - result on the held-out folds ("exam")
plt.plot(train_sizes, train_loss_mean, 'o-', color='r',
         label='trainning')
plt.plot(train_sizes, test_loss_mean, 'o-', color='g',
         label='cross-validation')
plt.xlabel('trainning examples')
plt.ylabel('loss')
plt.legend(loc='best')
plt.show()

03交叉验证-3

  使用validation_curve 来调节参数

 1 from sklearn.model_selection import validation_curve # 用来调模型的参数
 2 from sklearn.datasets import load_digits
 3 from sklearn.svm import SVC
 4 import matplotlib.pyplot as plt
 5 import numpy as np
 6 
 7 digis = load_digits()
 8 X = digis.data
 9 y = digis.target
10 # 在10**-6 到 10**-2.3 之间取5个点
11 param_range = np.logspace(-6,-2.3,5)
12 # 进行训练
13 # --------
14 # 训练的损失值 测试的损失值
15 train_loss,test_loss = validation_curve(
16     SVC(),X,y,param_name='gamma',param_range=param_range,
17    # 用到的方法  数据      要调动的参数      调动参数的范围
18     scoring='neg_mean_squared_error'
19 )  # 评分的标准
20 # --------
21
gamma cv1 cv2 cv3
10-6      
....      
.....      
.....      
10-2.3      






22
train_loss_mean = -np.mean(train_loss,axis=1) # 按y轴滑动求平均值 23 test_loss_mean = -np.mean(test_loss,axis=1) # 按y轴滑动求平均值 24 # 可视化 25 # 红色的线表示,训练完做作业的结果 26 # 绿色的线表示训练完考试的结果 27 plt.plot(param_range,train_loss_mean,'o-',color='r', 28 label='trainning') 29 plt.plot(param_range,test_loss_mean,'o-',color='g', 30 label='cross-validation') 31 plt.xlabel('trainning examples') 32 plt.ylabel('loss') 33 plt.legend(loc='best') 34 plt.show()

04附带如何保存训练的结果

 

from sklearn import svm
from sklearn import datasets

# Train a classifier that the two persistence methods below will save.
clf = svm.SVC()
iris = datasets.load_iris()
X, y = iris.data, iris.target
# Fit on the training data.
clf.fit(X, y)

# Both methods write into 'save/'; create it so open()/dump() do not fail
# with FileNotFoundError on a fresh checkout.
import os
os.makedirs('save', exist_ok=True)

# Method 1, for small models: pickle.
# NOTE: unpickling can execute arbitrary code; only load files you created
# yourself or otherwise trust.
import pickle
with open('save/clf.pickle', 'wb') as f:
    pickle.dump(clf, f)
with open('save/clf.pickle', 'rb') as f:
    clf2 = pickle.load(f)
    print(clf2.predict(X[0:3]))

# Method 2, for large models: joblib, which handles objects carrying large
# NumPy arrays more efficiently than plain pickle.
# `sklearn.externals.joblib` was removed from scikit-learn; import the
# standalone package and fall back to the old location only on old installs.
try:
    import joblib
except ImportError:
    from sklearn.externals import joblib
# Save
joblib.dump(clf, 'save/clf.pkl')
# Restore
clf3 = joblib.load('save/clf.pkl')
print(clf3.predict(X[12:15]))

 

posted @ 2019-06-10 14:13  childhood_2  阅读(606)  评论(0编辑  收藏  举报