博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

02scikit-learn模型训练

Posted on 2019-02-20 19:48  心默默言  阅读(211)  评论(0)  编辑  收藏  举报

模型训练

In [6]:
# Load the Boston housing dataset and set up a linear-regression model.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
# NOTE(review): load_boston was deprecated in scikit-learn 1.0 and removed
# in 1.2 — running this cell requires scikit-learn < 1.2.
from sklearn.datasets import load_boston

data = load_boston()
clf = LinearRegression()
# Shape of the feature matrix: (number of samples, number of features).
n_samples, n_features = data.data.shape
n_samples, n_features  # (506, 13)
Out[6]:
(506, 13)
In [11]:
data.keys()  # attributes of the dataset Bunch: data, target, feature_names, DESCR, filename
Out[11]:
dict_keys(['data', 'target', 'feature_names', 'DESCR', 'filename'])
In [12]:
data.feature_names  # names of the 13 feature columns
Out[12]:
array(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT'], dtype='<U7')
In [15]:
# Explore one feature: scatter it against the target house price.
feature_idx = 5
plt.scatter(data.data[:, feature_idx], data.target)
data.feature_names[feature_idx]  # 'RM' — the room-count feature
Out[15]:
'RM'
 
In [16]:
from sklearn.metrics import mean_absolute_error

# Fit the linear model on the full dataset and evaluate in-sample.
clf.fit(data.data, data.target)
predicted = clf.predict(data.data)
mean_absolute_error(data.target, predicted)  # MAE on the training data itself
Out[16]:
3.270862810900314
In [17]:
# Predicted vs. true prices; the red y = x line marks perfect prediction.
plt.scatter(data.target, predicted)
plt.xlabel('true_price')
plt.ylabel('predict_price')
plt.plot(data.target, data.target, color='red')
Out[17]:
[<matplotlib.lines.Line2D at 0x11c81eb8>]
 
In [20]:
# try another non_linear model
from sklearn.tree import DecisionTreeRegressor

clf2 = DecisionTreeRegressor()
clf2.fit(data.data, data.target)
predicted2 = clf2.predict(data.data)
# The in-sample MAE comes out as 0.0 (see output): the tree has memorised
# the training data, a sign of overfitting.
mean_absolute_error(data.target, predicted2)
Out[20]:
0.0
In [21]:
# Tree predictions vs. true prices; the red y = x line marks perfect prediction.
plt.scatter(data.target, predicted2)
plt.xlabel('true_price')
plt.ylabel('predict_price')
plt.plot(data.target, data.target, color='red')
Out[21]:
[<matplotlib.lines.Line2D at 0x11d292e8>]
 
 

上图训练得非常好,可能会产生过拟合

In [25]:
# practice classification model
# example Logistic Regression and probability prediction
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

iris = load_iris()
clf = LogisticRegression(solver='liblinear', multi_class='auto')
clf.fit(iris.data, iris.target)
probability = clf.predict_proba(iris.data)  # per-sample probability of each class label
probability  # e.g. first row below: 87.8% probability of class 0, 12.2% of
# class 1, and so on for the remaining rows
Out[25]:
array([[0.87803031, 0.1219589 , 0.00001079],
       [0.79705829, 0.20291141, 0.00003029],
       [0.85199767, 0.14797648, 0.00002586],
       [0.82340602, 0.17653616, 0.00005782],
       [0.89603497, 0.10395384, 0.00001119],
       [0.92623425, 0.07375278, 0.00001296],
       [0.89409685, 0.10586394, 0.00003922],
       [0.86003441, 0.13994671, 0.00001888],
       [0.80102864, 0.19888675, 0.0000846 ],
       [0.79266239, 0.207312  , 0.00002561],
       [0.89048611, 0.10950773, 0.00000616],
       [0.86180067, 0.13816496, 0.00003437],
       [0.78536437, 0.21460826, 0.00002737],
       [0.83312233, 0.1668456 , 0.00003207],
       [0.92710508, 0.07289396, 0.00000097],
       [0.96420978, 0.03578796, 0.00000226],
       [0.94024468, 0.05975048, 0.00000484],
       [0.89038364, 0.1096022 , 0.00001416],
       [0.89499643, 0.1049968 , 0.00000677],
       [0.92281833, 0.07716985, 0.00001182],
       [0.82816884, 0.17181599, 0.00001517],
       [0.9211629 , 0.07881927, 0.00001783],
       [0.92583055, 0.07416099, 0.00000846],
       [0.86642505, 0.13350683, 0.00006812],
       [0.83957906, 0.16034845, 0.00007249],
       [0.77438785, 0.22557074, 0.00004141],
       [0.88014221, 0.11981603, 0.00004176],
       [0.86814212, 0.13184633, 0.00001156],
       [0.85798154, 0.14200808, 0.00001039],
       [0.83013655, 0.16980939, 0.00005406],
       [0.80548889, 0.1944593 , 0.00005181],
       [0.87080741, 0.12917648, 0.00001611],
       [0.9331403 , 0.06685591, 0.00000379],
       [0.94556305, 0.05443497, 0.00000198],
       [0.8091041 , 0.19086202, 0.00003387],
       [0.84540667, 0.15458141, 0.00001192],
       [0.8678451 , 0.13215071, 0.00000419],
       [0.88781581, 0.11217401, 0.00001018],
       [0.82917332, 0.17076892, 0.00005776],
       [0.85578733, 0.14419685, 0.00001582],
       [0.89902143, 0.10096537, 0.00001321],
       [0.68760966, 0.31222729, 0.00016305],
       [0.86468741, 0.13526866, 0.00004393],
       [0.91572506, 0.08421267, 0.00006227],
       [0.91483865, 0.08511969, 0.00004165],
       [0.81813982, 0.18181229, 0.00004789],
       [0.90880255, 0.09118589, 0.00001156],
       [0.84952236, 0.15043821, 0.00003943],
       [0.89405093, 0.10594173, 0.00000734],
       [0.84954119, 0.15044186, 0.00001695],
       [0.02960053, 0.86126971, 0.10912976],
       [0.03735662, 0.70599864, 0.25664474],
       [0.01171882, 0.74918029, 0.23910089],
       [0.01323293, 0.65261527, 0.3341518 ],
       [0.0109261 , 0.69975168, 0.28932221],
       [0.01074757, 0.58352345, 0.40572897],
       [0.02155363, 0.53736735, 0.44107901],
       [0.10779613, 0.76976403, 0.12243984],
       [0.01756481, 0.82847163, 0.15396355],
       [0.0331001 , 0.52843482, 0.43846508],
       [0.02909409, 0.77362145, 0.19728446],
       [0.04095529, 0.61978222, 0.33926249],
       [0.01930641, 0.88041176, 0.10028182],
       [0.00871112, 0.59711982, 0.39416907],
       [0.16693542, 0.7134194 , 0.11964518],
       [0.0471498 , 0.84453959, 0.1083106 ],
       [0.01229146, 0.42322112, 0.56448741],
       [0.03811694, 0.85107181, 0.11081125],
       [0.00308283, 0.59723043, 0.39968674],
       [0.03569679, 0.80966589, 0.15463732],
       [0.00624631, 0.27162577, 0.72212792],
       [0.05767621, 0.82064253, 0.12168126],
       [0.00195123, 0.53464684, 0.46340193],
       [0.0087242 , 0.70558697, 0.28568882],
       [0.03660593, 0.83989907, 0.123495  ],
       [0.0358882 , 0.82963776, 0.13447404],
       [0.00807402, 0.77816015, 0.21376583],
       [0.00463307, 0.52364059, 0.47172634],
       [0.01333998, 0.56347986, 0.42318016],
       [0.12711691, 0.8329313 , 0.03995179],
       [0.03581044, 0.80413792, 0.16005163],
       [0.05003383, 0.84711273, 0.10285344],
       [0.05656025, 0.81218015, 0.1312596 ],
       [0.001226  , 0.39930356, 0.59947044],
       [0.01035901, 0.36404062, 0.62560038],
       [0.04192755, 0.47659618, 0.48147627],
       [0.01894857, 0.74644236, 0.23460907],
       [0.00699118, 0.75788979, 0.23511903],
       [0.05570461, 0.66760154, 0.27669385],
       [0.0210041 , 0.6630778 , 0.3159181 ],
       [0.00895359, 0.60041736, 0.39062905],
       [0.01518493, 0.63284951, 0.35196556],
       [0.03451475, 0.79953197, 0.16595328],
       [0.09088978, 0.79694616, 0.11216406],
       [0.01979054, 0.64145332, 0.33875613],
       [0.0479463 , 0.73160741, 0.22044629],
       [0.03437866, 0.67792614, 0.2876952 ],
       [0.03365277, 0.79775962, 0.16858761],
       [0.25317619, 0.69233045, 0.05449335],
       [0.03622935, 0.70484125, 0.2589294 ],
       [0.00018858, 0.14637262, 0.8534388 ],
       [0.00081403, 0.29344714, 0.70573883],
       [0.000279  , 0.33023907, 0.66948192],
       [0.00045801, 0.33833991, 0.66120208],
       [0.00025341, 0.25571436, 0.74403222],
       [0.00006041, 0.38291757, 0.61702202],
       [0.00206351, 0.2798062 , 0.71813029],
       [0.00012312, 0.42493487, 0.57494202],
       [0.0001599 , 0.42361552, 0.57622458],
       [0.00035986, 0.1507501 , 0.84889004],
       [0.00301206, 0.27686098, 0.72012696],
       [0.00064689, 0.35551374, 0.64383936],
       [0.00068392, 0.29819313, 0.70112296],
       [0.00063265, 0.29576839, 0.70359896],
       [0.00061817, 0.17242508, 0.82695675],
       [0.0011076 , 0.17147616, 0.82741624],
       [0.00080141, 0.34867019, 0.6505284 ],
       [0.00019462, 0.23736752, 0.76243786],
       [0.00001303, 0.42032294, 0.57966403],
       [0.00067999, 0.47061514, 0.52870487],
       [0.00050811, 0.22392596, 0.77556593],
       [0.0013505 , 0.22852031, 0.77012918],
       [0.00003818, 0.42849746, 0.57146436],
       [0.00205514, 0.40046885, 0.59747602],
       [0.00068243, 0.23590885, 0.76340873],
       [0.00045642, 0.39783264, 0.60171094],
       [0.00320524, 0.38361431, 0.61318045],
       [0.00343776, 0.32673717, 0.66982508],
       [0.00030239, 0.29738763, 0.70230998],
       [0.00067575, 0.51161821, 0.48770605],
       [0.00016147, 0.42854842, 0.57129011],
       [0.00064593, 0.34460403, 0.65475004],
       [0.00027729, 0.27685415, 0.72286856],
       [0.00207367, 0.49125249, 0.50667385],
       [0.00035439, 0.44307281, 0.5565728 ],
       [0.00018237, 0.34196617, 0.65785146],
       [0.00063838, 0.1265659 , 0.87279573],
       [0.00092554, 0.32031446, 0.67876   ],
       [0.00431283, 0.31746907, 0.6782181 ],
       [0.00117132, 0.3003005 , 0.69852818],
       [0.00045021, 0.20080005, 0.79874973],
       [0.00216404, 0.24761373, 0.75022222],
       [0.00081403, 0.29344714, 0.70573883],
       [0.00029358, 0.22339673, 0.77630969],
       [0.00045525, 0.15204928, 0.84749547],
       [0.00116469, 0.23233015, 0.76650517],
       [0.0009204 , 0.37926299, 0.6198166 ],
       [0.00146455, 0.29758429, 0.70095116],
       [0.00110986, 0.12983185, 0.8690583 ],
       [0.00169379, 0.27997339, 0.71833282]])
In [28]:
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC

# Train an SVM on iris and compare its training accuracy against the
# logistic-regression model (clf) fitted in the previous cell.
clf2 = SVC(gamma='auto')
clf2.fit(iris.data, iris.target)

predicted = clf.predict(iris.data)    # logistic regression
predicted2 = clf2.predict(iris.data)  # SVM

print(accuracy_score(iris.target, predicted))
print(accuracy_score(iris.target, predicted2))
 
0.96
0.9866666666666667