第4章下 最基础的分类算法-k近邻算法 kNN

 

4-6 网格搜索与k近邻算法中更多超参数

 

 Notebook 示例

 

 Notebook 源码

 

复制代码
  1 [1]
  2 import numpy as np
  3 from sklearn import datasets
  4 [2]
  5 digits = datasets.load_digits()
  6 X = digits.data
  7 y = digits.target
  8 [3]
  9 from sklearn.model_selection import train_test_split
 10 
 11 X_train, X_test, y_train, y_test = train_test_split(X,y,test_size = 0.3 ,random_state=111 )
 12 [4]
 13 from sklearn.neighbors import KNeighborsClassifier
 14 
 15 knn_clf = KNeighborsClassifier( n_neighbors = 6 )
 16 knn_clf.fit(X_train,y_train)
 17 knn_clf.score(X_test,y_test)
 18 0.9833333333333333
 19 Grid Search
 20 [5]
 21 param_gid = [
 22     {
 23         'weights': ['unifrom'],
 24         'n_neighbors': [ i for i in range(1,11)]
 25     },
 26     {
 27         'weights': ['distance'],
 28         'n_neighbors': [ i for i in range(1,11)],
 29         'p': [ i for i in range(1,6)]
 30     }
 31     
 32 ]
 33 [6]
 34 knn_clf = KNeighborsClassifier()
 35 [7]
 36 from sklearn.model_selection import GridSearchCV
 37 
 38 grid_search = GridSearchCV(knn_clf,param_gid)
 39 [8]
 40 %%time
 41 grid_search.fit(X_train,y_train)
 42 CPU times: total: 2min 15s
 43 Wall time: 2min 18s
 44 
 45 F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py:372: FitFailedWarning: 
 46 50 fits failed out of a total of 300.
 47 The score on these train-test partitions for these parameters will be set to nan.
 48 If these failures are not expected, you can try to debug them by setting error_score='raise'.
 49 
 50 Below are more details about the failures:
 51 --------------------------------------------------------------------------------
 52 50 fits failed with the following error:
 53 Traceback (most recent call last):
 54   File "F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py", line 680, in _fit_and_score
 55     estimator.fit(X_train, y_train, **fit_params)
 56   File "F:\anaconda\lib\site-packages\sklearn\neighbors\_classification.py", line 196, in fit
 57     self.weights = _check_weights(self.weights)
 58   File "F:\anaconda\lib\site-packages\sklearn\neighbors\_base.py", line 82, in _check_weights
 59     raise ValueError(
 60 ValueError: weights not recognized: should be 'uniform', 'distance', or a callable function
 61 
 62   warnings.warn(some_fits_failed_message, FitFailedWarning)
 63 F:\anaconda\lib\site-packages\sklearn\model_selection\_search.py:969: UserWarning: One or more of the test scores are non-finite: [       nan        nan        nan        nan        nan        nan
 64         nan        nan        nan        nan 0.98011446 0.98965724
 65  0.98965724 0.99204452 0.98966041 0.98090811 0.98965724 0.98965724
 66  0.99204452 0.98966041 0.98408904 0.98726997 0.98646999 0.98726681
 67  0.98488585 0.98249542 0.98806678 0.98886359 0.98886359 0.98726997
 68  0.98249542 0.98647948 0.98966041 0.98726997 0.98488902 0.98249542
 69  0.98488585 0.9856795  0.98885727 0.9856795  0.98090179 0.98329539
 70  0.9856795  0.98806362 0.98488269 0.98010181 0.98170176 0.9856795
 71  0.98726997 0.98487637 0.97692405 0.98408904 0.98329223 0.98647948
 72  0.98488585 0.97851135 0.98010814 0.98488269 0.98726997 0.98726997]
 73   warnings.warn(
 74 
 75 GridSearchCV(estimator=KNeighborsClassifier(),
 76              param_grid=[{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
 77                           'weights': ['unifrom']},
 78                          {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
 79                           'p': [1, 2, 3, 4, 5], 'weights': ['distance']}])
 80 [9]
 81 grid_search.best_estimator_
 82 KNeighborsClassifier(n_neighbors=1, p=4, weights='distance')
 83 [10]
 84 grid_search.best_score_
 85 0.9920445203313729
 86 [11]
 87 grid_search.best_params_
 88 {'n_neighbors': 1, 'p': 4, 'weights': 'distance'}
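 89 若随机种子 random_state=111,则 grid_search.best_score_ = 0.9920445203313729, grid_search.best_params_ = {'n_neighbors': 1, 'p': 4, 'weights': 'distance'}。注意:param_gid 中的 'unifrom' 是 'uniform' 的拼写错误,导致 weights='uniform' 的 10 组候选(共 50 次拟合)全部失败、得分为 nan(见上方 FitFailedWarning),因此这次网格搜索实际上只比较了 weights='distance' 的 50 组候选;将其改为 'uniform' 后需要重新运行网格搜索。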
 90 
 91 [12]
 92 knn_clf = grid_search.best_estimator_
 93 [13]
 94 knn_clf.predict(X_test)
 95 array([7, 1, 2, 9, 5, 8, 6, 4, 8, 3, 8, 4, 4, 5, 7, 1, 6, 1, 0, 6, 6, 8,
 96        8, 0, 8, 4, 5, 8, 0, 0, 3, 3, 5, 2, 1, 4, 8, 6, 7, 3, 3, 9, 6, 0,
 97        4, 9, 7, 3, 8, 7, 4, 3, 5, 0, 3, 1, 7, 6, 5, 7, 6, 0, 9, 7, 7, 8,
 98        2, 8, 6, 6, 1, 1, 2, 6, 4, 6, 4, 8, 6, 9, 8, 1, 3, 4, 4, 2, 0, 7,
 99        6, 0, 8, 2, 0, 5, 8, 5, 3, 3, 7, 4, 7, 3, 4, 2, 4, 9, 1, 8, 5, 1,
100        2, 7, 0, 2, 8, 9, 7, 5, 7, 7, 8, 8, 9, 2, 3, 9, 7, 7, 8, 2, 5, 3,
101        2, 4, 0, 1, 4, 8, 7, 9, 6, 8, 1, 5, 2, 6, 1, 4, 1, 6, 5, 3, 4, 2,
102        2, 7, 0, 7, 1, 5, 4, 6, 1, 7, 4, 9, 6, 8, 5, 8, 4, 3, 3, 2, 5, 6,
103        7, 9, 0, 2, 0, 5, 4, 8, 0, 8, 6, 9, 7, 3, 1, 9, 4, 2, 7, 9, 4, 0,
104        5, 2, 8, 2, 9, 1, 8, 5, 4, 5, 7, 7, 5, 5, 0, 1, 4, 4, 6, 5, 7, 6,
105        0, 6, 7, 1, 9, 0, 6, 1, 2, 9, 1, 5, 3, 0, 2, 1, 0, 9, 3, 4, 1, 0,
106        9, 9, 2, 0, 5, 3, 6, 5, 5, 3, 9, 1, 2, 8, 7, 4, 9, 8, 8, 1, 3, 1,
107        6, 3, 0, 7, 2, 4, 7, 2, 5, 0, 6, 4, 7, 4, 1, 0, 3, 1, 8, 0, 5, 6,
108        9, 5, 5, 0, 6, 0, 5, 2, 9, 7, 2, 9, 1, 0, 3, 5, 8, 8, 0, 4, 3, 4,
109        6, 1, 6, 1, 7, 3, 3, 2, 3, 6, 7, 1, 0, 1, 9, 6, 6, 6, 8, 2, 3, 5,
110        9, 4, 4, 5, 3, 9, 7, 1, 3, 0, 0, 8, 6, 9, 7, 9, 6, 4, 2, 7, 2, 6,
111        5, 4, 1, 7, 9, 0, 1, 1, 7, 5, 3, 3, 7, 4, 9, 0, 8, 6, 0, 9, 1, 9,
112        7, 8, 8, 8, 6, 2, 1, 3, 0, 2, 3, 6, 8, 1, 6, 1, 3, 9, 6, 2, 5, 2,
113        9, 7, 7, 6, 5, 8, 0, 1, 8, 6, 3, 5, 0, 4, 3, 9, 9, 3, 4, 3, 7, 9,
114        2, 3, 5, 3, 9, 3, 1, 4, 7, 7, 1, 7, 4, 3, 0, 8, 0, 9, 6, 3, 9, 8,
115        3, 9, 9, 9, 4, 1, 6, 7, 7, 2, 0, 1, 0, 7, 5, 7, 6, 1, 5, 0, 6, 9,
116        5, 1, 2, 1, 7, 5, 2, 1, 8, 1, 8, 8, 2, 8, 6, 8, 7, 0, 9, 9, 6, 2,
117        0, 9, 6, 3, 4, 3, 0, 8, 5, 4, 8, 6, 4, 5, 2, 5, 6, 1, 0, 5, 7, 0,
118        9, 5, 3, 2, 9, 3, 0, 6, 4, 8, 3, 2, 3, 6, 6, 8, 1, 9, 4, 3, 1, 1,
119        4, 5, 4, 3, 7, 5, 3, 3, 7, 8, 1, 0])
120 [14]
121 knn_clf.score(X_test,y_test)
122 0.9907407407407407
123 [15]
124 %%time
125 grid_search = GridSearchCV(knn_clf,param_gid,n_jobs= 4, verbose = 2)
126 grid_search.fit(X_train,y_train)
127 # 创建多个分类器来比较,可以并行处理,n_jobs 为分配核的数量,默认为单核 1 .-1为全核。
128 # verbose,及时输出一些信息,值越大越详细
129 Fitting 5 folds for each of 60 candidates, totalling 300 fits
130 CPU times: total: 484 ms
131 Wall time: 1min 28s
132 
133 F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py:372: FitFailedWarning: 
134 50 fits failed out of a total of 300.
135 The score on these train-test partitions for these parameters will be set to nan.
136 If these failures are not expected, you can try to debug them by setting error_score='raise'.
137 
138 Below are more details about the failures:
139 --------------------------------------------------------------------------------
140 50 fits failed with the following error:
141 Traceback (most recent call last):
142   File "F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py", line 680, in _fit_and_score
143     estimator.fit(X_train, y_train, **fit_params)
144   File "F:\anaconda\lib\site-packages\sklearn\neighbors\_classification.py", line 196, in fit
145     self.weights = _check_weights(self.weights)
146   File "F:\anaconda\lib\site-packages\sklearn\neighbors\_base.py", line 82, in _check_weights
147     raise ValueError(
148 ValueError: weights not recognized: should be 'uniform', 'distance', or a callable function
149 
150   warnings.warn(some_fits_failed_message, FitFailedWarning)
151 F:\anaconda\lib\site-packages\sklearn\model_selection\_search.py:969: UserWarning: One or more of the test scores are non-finite: [       nan        nan        nan        nan        nan        nan
152         nan        nan        nan        nan 0.98011446 0.98965724
153  0.98965724 0.99204452 0.98966041 0.98090811 0.98965724 0.98965724
154  0.99204452 0.98966041 0.98408904 0.98726997 0.98646999 0.98726681
155  0.98488585 0.98249542 0.98806678 0.98886359 0.98886359 0.98726997
156  0.98249542 0.98647948 0.98966041 0.98726997 0.98488902 0.98249542
157  0.98488585 0.9856795  0.98885727 0.9856795  0.98090179 0.98329539
158  0.9856795  0.98806362 0.98488269 0.98010181 0.98170176 0.9856795
159  0.98726997 0.98487637 0.97692405 0.98408904 0.98329223 0.98647948
160  0.98488585 0.97851135 0.98010814 0.98488269 0.98726997 0.98726997]
161   warnings.warn(
162 
163 GridSearchCV(estimator=KNeighborsClassifier(n_neighbors=1, p=4,
164                                             weights='distance'),
165              n_jobs=4,
166              param_grid=[{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
167                           'weights': ['unifrom']},
168                          {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
169                           'p': [1, 2, 3, 4, 5], 'weights': ['distance']}],
170              verbose=2)
复制代码

 

 

 

4-7 数据归一化

 

 

 

 

 

 

 

 

 

 

 Notebook 示例

 

 

 

 

 

 

Notebook 源码

复制代码
  1 数据归一化处理
  2 [1]
  3 import numpy as np
  4 import matplotlib.pyplot as plt
  5 最值归一化 Normalization
  6 [2]
  7 x = np.random.randint(0,100, size = 100)
  8 [3]
  9 x
 10 array([ 2, 58, 55, 40, 68,  7, 72,  4, 50, 89, 96, 19, 71,  6, 41, 40, 63,
 11         4,  2, 26, 79, 62, 62, 92, 49, 46, 75, 47, 44, 91, 40, 67, 38, 14,
 12        13,  0, 93, 15, 20, 50, 38, 31, 41, 22, 36, 85, 64, 87, 98, 65, 31,
 13        10, 46, 22, 86, 24, 68, 25, 11, 31, 22, 87, 84, 18, 58, 87,  4, 15,
 14        64, 92, 70, 59, 74, 81, 47, 33,  1, 93,  6, 37, 62, 17, 58, 56, 98,
 15        53,  2, 70, 36, 17, 21, 66,  1, 79, 60, 71, 89, 71, 61, 30])
 16 [5]
 17 ( x - np.min(x)) / (np.max(x) - np.min(x))
 18 array([0.02040816, 0.59183673, 0.56122449, 0.40816327, 0.69387755,
 19        0.07142857, 0.73469388, 0.04081633, 0.51020408, 0.90816327,
 20        0.97959184, 0.19387755, 0.7244898 , 0.06122449, 0.41836735,
 21        0.40816327, 0.64285714, 0.04081633, 0.02040816, 0.26530612,
 22        0.80612245, 0.63265306, 0.63265306, 0.93877551, 0.5       ,
 23        0.46938776, 0.76530612, 0.47959184, 0.44897959, 0.92857143,
 24        0.40816327, 0.68367347, 0.3877551 , 0.14285714, 0.13265306,
 25        0.        , 0.94897959, 0.15306122, 0.20408163, 0.51020408,
 26        0.3877551 , 0.31632653, 0.41836735, 0.2244898 , 0.36734694,
 27        0.86734694, 0.65306122, 0.8877551 , 1.        , 0.66326531,
 28        0.31632653, 0.10204082, 0.46938776, 0.2244898 , 0.87755102,
 29        0.24489796, 0.69387755, 0.25510204, 0.1122449 , 0.31632653,
 30        0.2244898 , 0.8877551 , 0.85714286, 0.18367347, 0.59183673,
 31        0.8877551 , 0.04081633, 0.15306122, 0.65306122, 0.93877551,
 32        0.71428571, 0.60204082, 0.75510204, 0.82653061, 0.47959184,
 33        0.33673469, 0.01020408, 0.94897959, 0.06122449, 0.37755102,
 34        0.63265306, 0.17346939, 0.59183673, 0.57142857, 1.        ,
 35        0.54081633, 0.02040816, 0.71428571, 0.36734694, 0.17346939,
 36        0.21428571, 0.67346939, 0.01020408, 0.80612245, 0.6122449 ,
 37        0.7244898 , 0.90816327, 0.7244898 , 0.62244898, 0.30612245])
 38 [6]
 39 X = np.random.randint(0,100,(50,2))
 40 [7]
 41 X[:10,:]
 42 array([[19, 14],
 43        [23, 82],
 44        [ 4, 17],
 45        [44, 58],
 46        [23, 91],
 47        [46, 17],
 48        [34, 25],
 49        [29, 39],
 50        [69, 61],
 51        [70, 25]])
 52 [9]
 53 X = np.array(X,dtype = float)
 54 [10]
 55 X[:10,:]
 56 array([[19., 14.],
 57        [23., 82.],
 58        [ 4., 17.],
 59        [44., 58.],
 60        [23., 91.],
 61        [46., 17.],
 62        [34., 25.],
 63        [29., 39.],
 64        [69., 61.],
 65        [70., 25.]])
 66 [18]
 67 X[:,0] = (X[:,0] - np.min(X[:,0])) / ( np.max(X[:,0]) - np.min(X[:,0]))
 68 [19]
 69 X[:,1] = (X[:,1] - np.min(X[:,1])) / ( np.max(X[:,1]) - np.min(X[:,1]))
 70 [20]
 71 X[:10,:]
 72 array([[0.19191919, 0.11458333],
 73        [0.23232323, 0.82291667],
 74        [0.04040404, 0.14583333],
 75        [0.44444444, 0.57291667],
 76        [0.23232323, 0.91666667],
 77        [0.46464646, 0.14583333],
 78        [0.34343434, 0.22916667],
 79        [0.29292929, 0.375     ],
 80        [0.6969697 , 0.60416667],
 81        [0.70707071, 0.22916667]])
 82 [21]
 83 plt.scatter(X[:,0],X[:,1])
 84 <matplotlib.collections.PathCollection at 0x1f514d5c2e0>
 85 
 86 [22]
 87 np.mean(X[:,0])
 88 0.4503030303030303
 89 [26]
 90 np.std(X[:,0])
 91 0.32224653972392703
 92 [24]
 93 np.mean(X[:,1])
 94 0.4503030303030303
 95 [27]
 96 np.std(X[:,1])
 97 0.3004160887650475
 98 均值方差归一化
 99 [28]
100 X2  =  np.random.randint(0,100,(50,2))
101 [29]
102 X2 = np.array(X2,dtype = float)
103 [36]
104 X2[:,0] = (X2[:,0] - np.mean(X2[:,0])) /  np.std(X2[:,0])
105 [37]
106 X2[:,1] = (X2[:,1] - np.mean(X2[:,1])) /  np.std(X2[:,1])
107 [38]
108 plt.scatter(X2[:,0],X2[:,1])
109 <matplotlib.collections.PathCollection at 0x1f517fa1430>
110 
111 [39]
112 np.mean(X2[:,0])
113 0.0
114 [40]
115 np.std(X2[:,0])
116 1.0
117 [41]
118 np.mean(X2[:,1])
119 -4.4408920985006264e-17
120 [42]
121 np.std(X2[:,1])
122 1.0
复制代码

 

4-8 scikit-learn中的Scaler

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

Notebook 示例

 

 

 

Notebook 源码

 

复制代码
  1 [1]
  2 import numpy as np
  3 from sklearn import datasets
  4 [2]
  5 iris = datasets.load_iris()
  6 [3]
  7 X = iris.data
  8 y = iris.target
  9 [4]
 10 X[:10,:]
 11 array([[5.1, 3.5, 1.4, 0.2],
 12        [4.9, 3. , 1.4, 0.2],
 13        [4.7, 3.2, 1.3, 0.2],
 14        [4.6, 3.1, 1.5, 0.2],
 15        [5. , 3.6, 1.4, 0.2],
 16        [5.4, 3.9, 1.7, 0.4],
 17        [4.6, 3.4, 1.4, 0.3],
 18        [5. , 3.4, 1.5, 0.2],
 19        [4.4, 2.9, 1.4, 0.2],
 20        [4.9, 3.1, 1.5, 0.1]])
 21 [5]
 22 from sklearn.model_selection import train_test_split
 23 X_train, X_test, y_train, y_test = train_test_split(X,y,test_size = 0.3 ,random_state=666 )
 24 scikit-learn中的StandardScaler
 25 [6]
 26 from sklearn.preprocessing import StandardScaler
 27 [7]
 28 standardScaler = StandardScaler()
 29 [8]
 30 standardScaler.fit(X_train)
 31 StandardScaler()
 32 [9]
 33 standardScaler.mean_
 34 array([5.81619048, 3.08761905, 3.66952381, 1.15714286])
 35 [10]
 36 standardScaler.scale_ #   .std 旧形式已经弃用 , scale表示数据分布范围
 37 array([0.80747977, 0.43789436, 1.76166176, 0.75464998])
 38 [11]
 39 standardScaler.transform(X_train)
 40 array([[-0.63926119,  1.39846731, -1.2315212 , -1.26832688],
 41        [-1.01078752,  0.94173615, -1.17475662, -0.73827982],
 42        [-1.75384019, -0.42845732, -1.28828579, -1.26832688],
 43        [-0.02005063, -0.88518848,  0.13082885,  0.05679076],
 44        [-0.7631033 ,  0.71337057, -1.28828579, -1.26832688],
 45        [-1.50615597,  0.71337057, -1.28828579, -1.13581511],
 46        [ 0.84684415,  0.25663941,  0.81200388,  1.11688486],
 47        [-0.14389274, -0.42845732,  0.30112261,  0.18930252],
 48        [ 0.97068626, -0.20009175,  0.41465178,  0.32181428],
 49        [ 0.2276336 , -0.42845732,  0.47141637,  0.45432605],
 50        [-1.38231385,  0.25663941, -1.17475662, -1.26832688],
 51        [-1.13462963,  1.17010173, -1.28828579, -1.40083864],
 52        [ 1.09452838,  0.02827383,  1.09582681,  1.64693191],
 53        [ 0.59915993, -0.88518848,  0.69847471,  0.85186133],
 54        [ 0.35147571, -0.6568229 ,  0.58494554,  0.05679076],
 55        [ 0.47531782, -0.6568229 ,  0.64171013,  0.85186133],
 56        [-0.14389274,  2.99702636, -1.2315212 , -1.00330335],
 57        [ 0.59915993, -1.34191964,  0.69847471,  0.45432605],
 58        [ 0.72300204, -0.42845732,  0.3578872 ,  0.18930252],
 59        [-0.88694541,  1.62683289, -1.00446286, -1.00330335],
 60        [ 1.21837049, -0.6568229 ,  0.64171013,  0.32181428],
 61        [-0.88694541,  0.94173615, -1.28828579, -1.13581511],
 62        [-1.8776823 , -0.20009175, -1.45857955, -1.40083864],
 63        [ 0.10379148, -0.20009175,  0.81200388,  0.85186133],
 64        [ 0.72300204, -0.6568229 ,  1.09582681,  1.24939662],
 65        [-0.26773485, -0.6568229 ,  0.69847471,  1.11688486],
 66        [-0.39157696, -1.57028522,  0.01729968, -0.20823277],
 67        [ 1.3422126 ,  0.02827383,  0.69847471,  0.45432605],
 68        [ 0.59915993,  0.71337057,  1.09582681,  1.64693191],
 69        [ 0.84684415, -0.20009175,  1.20935598,  1.38190839],
 70        [-0.14389274,  1.62683289, -1.11799203, -1.13581511],
 71        [ 0.97068626, -0.42845732,  0.52818096,  0.18930252],
 72        [ 1.09452838,  0.48500499,  1.1525914 ,  1.77944368],
 73        [-1.25847174, -0.20009175, -1.28828579, -1.40083864],
 74        [-1.01078752,  1.17010173, -1.28828579, -1.26832688],
 75        [ 0.2276336 , -0.20009175,  0.64171013,  0.85186133],
 76        [-1.01078752, -0.20009175, -1.17475662, -1.26832688],
 77        [ 0.35147571, -0.20009175,  0.69847471,  0.85186133],
 78        [ 0.72300204,  0.02827383,  1.03906223,  0.85186133],
 79        [-0.88694541,  1.39846731, -1.2315212 , -1.00330335],
 80        [-0.14389274, -0.20009175,  0.30112261,  0.05679076],
 81        [-1.01078752,  0.94173615, -1.34505037, -1.13581511],
 82        [-0.88694541,  1.62683289, -1.2315212 , -1.13581511],
 83        [-1.50615597,  0.25663941, -1.28828579, -1.26832688],
 84        [-0.51541907, -0.20009175,  0.47141637,  0.45432605],
 85        [ 0.84684415, -0.6568229 ,  0.52818096,  0.45432605],
 86        [ 0.35147571, -0.6568229 ,  0.18759344,  0.18930252],
 87        [-1.25847174,  0.71337057, -1.17475662, -1.26832688],
 88        [-0.88694541,  0.48500499, -1.11799203, -0.87079159],
 89        [-0.02005063, -0.88518848,  0.81200388,  0.9843731 ],
 90        [-0.26773485, -0.20009175,  0.24435803,  0.18930252],
 91        [ 0.59915993, -0.6568229 ,  0.81200388,  0.45432605],
 92        [ 1.09452838,  0.48500499,  1.1525914 ,  1.24939662],
 93        [ 1.71373893, -0.20009175,  1.20935598,  0.58683781],
 94        [ 1.09452838, -0.20009175,  0.86876847,  1.51442015],
 95        [-1.13462963,  0.02827383, -1.2315212 , -1.40083864],
 96        [-1.13462963, -1.34191964,  0.47141637,  0.71934957],
 97        [-0.14389274, -1.34191964,  0.7552393 ,  1.11688486],
 98        [-1.13462963, -1.57028522, -0.20975866, -0.20823277],
 99        [-0.39157696, -1.57028522,  0.07406427, -0.07572101],
100        [ 1.09452838, -1.34191964,  1.20935598,  0.85186133],
101        [ 0.84684415, -0.20009175,  1.03906223,  0.85186133],
102        [-0.14389274, -1.11355406, -0.09622949, -0.20823277],
103        [ 0.2276336 , -2.02701638,  0.7552393 ,  0.45432605],
104        [ 1.09452838,  0.02827383,  0.58494554,  0.45432605],
105        [-1.13462963,  0.02827383, -1.2315212 , -1.26832688],
106        [ 0.59915993, -1.34191964,  0.7552393 ,  0.9843731 ],
107        [-1.38231385,  0.25663941, -1.34505037, -1.26832688],
108        [ 0.2276336 , -0.88518848,  0.81200388,  0.58683781],
109        [-0.02005063, -1.11355406,  0.18759344,  0.05679076],
110        [ 1.3422126 ,  0.25663941,  1.1525914 ,  1.51442015],
111        [-1.75384019, -0.20009175, -1.34505037, -1.26832688],
112        [ 1.58989682, -0.20009175,  1.26612057,  1.24939662],
113        [ 1.21837049,  0.25663941,  1.26612057,  1.51442015],
114        [-0.7631033 ,  0.94173615, -1.2315212 , -1.26832688],
115        [ 2.58063371,  1.62683289,  1.5499435 ,  1.11688486],
116        [ 0.72300204, -0.6568229 ,  1.09582681,  1.38190839],
117        [-0.26773485, -0.42845732, -0.0394649 ,  0.18930252],
118        [-0.39157696,  2.5402952 , -1.28828579, -1.26832688],
119        [-1.25847174, -0.20009175, -1.28828579, -1.13581511],
120        [ 0.59915993, -0.42845732,  1.09582681,  0.85186133],
121        [-1.75384019,  0.25663941, -1.34505037, -1.26832688],
122        [-0.51541907,  1.85519847, -1.11799203, -1.00330335],
123        [-1.01078752,  0.71337057, -1.17475662, -1.00330335],
124        [ 1.09452838, -0.20009175,  0.7552393 ,  0.71934957],
125        [-0.51541907,  1.85519847, -1.34505037, -1.00330335],
126        [ 2.33294949, -0.6568229 ,  1.72023726,  1.11688486],
127        [-0.26773485, -0.88518848,  0.30112261,  0.18930252],
128        [ 1.21837049, -0.20009175,  1.03906223,  1.24939662],
129        [-0.39157696,  0.94173615, -1.34505037, -1.26832688],
130        [-1.25847174,  0.71337057, -1.00446286, -1.26832688],
131        [-0.51541907,  0.71337057, -1.11799203, -1.26832688],
132        [ 2.33294949,  1.62683289,  1.72023726,  1.38190839],
133        [ 1.3422126 ,  0.02827383,  0.98229764,  1.24939662],
134        [-0.26773485, -1.34191964,  0.13082885, -0.07572101],
135        [-0.88694541,  0.71337057, -1.2315212 , -1.26832688],
136        [-0.88694541,  1.62683289, -1.17475662, -1.26832688],
137        [ 0.35147571, -0.42845732,  0.58494554,  0.32181428],
138        [-0.02005063,  2.08356405, -1.40181496, -1.26832688],
139        [-1.01078752, -2.48374754, -0.09622949, -0.20823277],
140        [ 0.72300204,  0.25663941,  0.47141637,  0.45432605],
141        [ 0.35147571, -0.20009175,  0.52818096,  0.32181428],
142        [ 0.10379148,  0.25663941,  0.64171013,  0.85186133],
143        [ 0.2276336 , -2.02701638,  0.18759344, -0.20823277],
144        [ 1.96142316, -0.6568229 ,  1.37964974,  0.9843731 ]])
145 [12]
146 X_train
147 array([[5.3, 3.7, 1.5, 0.2],
148        [5. , 3.5, 1.6, 0.6],
149        [4.4, 2.9, 1.4, 0.2],
150        [5.8, 2.7, 3.9, 1.2],
151        [5.2, 3.4, 1.4, 0.2],
152        [4.6, 3.4, 1.4, 0.3],
153        [6.5, 3.2, 5.1, 2. ],
154        [5.7, 2.9, 4.2, 1.3],
155        [6.6, 3. , 4.4, 1.4],
156        [6. , 2.9, 4.5, 1.5],
157        [4.7, 3.2, 1.6, 0.2],
158        [4.9, 3.6, 1.4, 0.1],
159        [6.7, 3.1, 5.6, 2.4],
160        [6.3, 2.7, 4.9, 1.8],
161        [6.1, 2.8, 4.7, 1.2],
162        [6.2, 2.8, 4.8, 1.8],
163        [5.7, 4.4, 1.5, 0.4],
164        [6.3, 2.5, 4.9, 1.5],
165        [6.4, 2.9, 4.3, 1.3],
166        [5.1, 3.8, 1.9, 0.4],
167        [6.8, 2.8, 4.8, 1.4],
168        [5.1, 3.5, 1.4, 0.3],
169        [4.3, 3. , 1.1, 0.1],
170        [5.9, 3. , 5.1, 1.8],
171        [6.4, 2.8, 5.6, 2.1],
172        [5.6, 2.8, 4.9, 2. ],
173        [5.5, 2.4, 3.7, 1. ],
174        [6.9, 3.1, 4.9, 1.5],
175        [6.3, 3.4, 5.6, 2.4],
176        [6.5, 3. , 5.8, 2.2],
177        [5.7, 3.8, 1.7, 0.3],
178        [6.6, 2.9, 4.6, 1.3],
179        [6.7, 3.3, 5.7, 2.5],
180        [4.8, 3. , 1.4, 0.1],
181        [5. , 3.6, 1.4, 0.2],
182        [6. , 3. , 4.8, 1.8],
183        [5. , 3. , 1.6, 0.2],
184        [6.1, 3. , 4.9, 1.8],
185        [6.4, 3.1, 5.5, 1.8],
186        [5.1, 3.7, 1.5, 0.4],
187        [5.7, 3. , 4.2, 1.2],
188        [5. , 3.5, 1.3, 0.3],
189        [5.1, 3.8, 1.5, 0.3],
190        [4.6, 3.2, 1.4, 0.2],
191        [5.4, 3. , 4.5, 1.5],
192        [6.5, 2.8, 4.6, 1.5],
193        [6.1, 2.8, 4. , 1.3],
194        [4.8, 3.4, 1.6, 0.2],
195        [5.1, 3.3, 1.7, 0.5],
196        [5.8, 2.7, 5.1, 1.9],
197        [5.6, 3. , 4.1, 1.3],
198        [6.3, 2.8, 5.1, 1.5],
199        [6.7, 3.3, 5.7, 2.1],
200        [7.2, 3. , 5.8, 1.6],
201        [6.7, 3. , 5.2, 2.3],
202        [4.9, 3.1, 1.5, 0.1],
203        [4.9, 2.5, 4.5, 1.7],
204        [5.7, 2.5, 5. , 2. ],
205        [4.9, 2.4, 3.3, 1. ],
206        [5.5, 2.4, 3.8, 1.1],
207        [6.7, 2.5, 5.8, 1.8],
208        [6.5, 3. , 5.5, 1.8],
209        [5.7, 2.6, 3.5, 1. ],
210        [6. , 2.2, 5. , 1.5],
211        [6.7, 3.1, 4.7, 1.5],
212        [4.9, 3.1, 1.5, 0.2],
213        [6.3, 2.5, 5. , 1.9],
214        [4.7, 3.2, 1.3, 0.2],
215        [6. , 2.7, 5.1, 1.6],
216        [5.8, 2.6, 4. , 1.2],
217        [6.9, 3.2, 5.7, 2.3],
218        [4.4, 3. , 1.3, 0.2],
219        [7.1, 3. , 5.9, 2.1],
220        [6.8, 3.2, 5.9, 2.3],
221        [5.2, 3.5, 1.5, 0.2],
222        [7.9, 3.8, 6.4, 2. ],
223        [6.4, 2.8, 5.6, 2.2],
224        [5.6, 2.9, 3.6, 1.3],
225        [5.5, 4.2, 1.4, 0.2],
226        [4.8, 3. , 1.4, 0.3],
227        [6.3, 2.9, 5.6, 1.8],
228        [4.4, 3.2, 1.3, 0.2],
229        [5.4, 3.9, 1.7, 0.4],
230        [5. , 3.4, 1.6, 0.4],
231        [6.7, 3. , 5. , 1.7],
232        [5.4, 3.9, 1.3, 0.4],
233        [7.7, 2.8, 6.7, 2. ],
234        [5.6, 2.7, 4.2, 1.3],
235        [6.8, 3. , 5.5, 2.1],
236        [5.5, 3.5, 1.3, 0.2],
237        [4.8, 3.4, 1.9, 0.2],
238        [5.4, 3.4, 1.7, 0.2],
239        [7.7, 3.8, 6.7, 2.2],
240        [6.9, 3.1, 5.4, 2.1],
241        [5.6, 2.5, 3.9, 1.1],
242        [5.1, 3.4, 1.5, 0.2],
243        [5.1, 3.8, 1.6, 0.2],
244        [6.1, 2.9, 4.7, 1.4],
245        [5.8, 4. , 1.2, 0.2],
246        [5. , 2. , 3.5, 1. ],
247        [6.4, 3.2, 4.5, 1.5],
248        [6.1, 3. , 4.6, 1.4],
249        [5.9, 3.2, 4.8, 1.8],
250        [6. , 2.2, 4. , 1. ],
251        [7.4, 2.8, 6.1, 1.9]])
252 [13]
253 X_train = standardScaler.transform(X_train)
254 [14]
255 X_train
256 array([[-0.63926119,  1.39846731, -1.2315212 , -1.26832688],
257        [-1.01078752,  0.94173615, -1.17475662, -0.73827982],
258        [-1.75384019, -0.42845732, -1.28828579, -1.26832688],
259        [-0.02005063, -0.88518848,  0.13082885,  0.05679076],
260        [-0.7631033 ,  0.71337057, -1.28828579, -1.26832688],
261        [-1.50615597,  0.71337057, -1.28828579, -1.13581511],
262        [ 0.84684415,  0.25663941,  0.81200388,  1.11688486],
263        [-0.14389274, -0.42845732,  0.30112261,  0.18930252],
264        [ 0.97068626, -0.20009175,  0.41465178,  0.32181428],
265        [ 0.2276336 , -0.42845732,  0.47141637,  0.45432605],
266        [-1.38231385,  0.25663941, -1.17475662, -1.26832688],
267        [-1.13462963,  1.17010173, -1.28828579, -1.40083864],
268        [ 1.09452838,  0.02827383,  1.09582681,  1.64693191],
269        [ 0.59915993, -0.88518848,  0.69847471,  0.85186133],
270        [ 0.35147571, -0.6568229 ,  0.58494554,  0.05679076],
271        [ 0.47531782, -0.6568229 ,  0.64171013,  0.85186133],
272        [-0.14389274,  2.99702636, -1.2315212 , -1.00330335],
273        [ 0.59915993, -1.34191964,  0.69847471,  0.45432605],
274        [ 0.72300204, -0.42845732,  0.3578872 ,  0.18930252],
275        [-0.88694541,  1.62683289, -1.00446286, -1.00330335],
276        [ 1.21837049, -0.6568229 ,  0.64171013,  0.32181428],
277        [-0.88694541,  0.94173615, -1.28828579, -1.13581511],
278        [-1.8776823 , -0.20009175, -1.45857955, -1.40083864],
279        [ 0.10379148, -0.20009175,  0.81200388,  0.85186133],
280        [ 0.72300204, -0.6568229 ,  1.09582681,  1.24939662],
281        [-0.26773485, -0.6568229 ,  0.69847471,  1.11688486],
282        [-0.39157696, -1.57028522,  0.01729968, -0.20823277],
283        [ 1.3422126 ,  0.02827383,  0.69847471,  0.45432605],
284        [ 0.59915993,  0.71337057,  1.09582681,  1.64693191],
285        [ 0.84684415, -0.20009175,  1.20935598,  1.38190839],
286        [-0.14389274,  1.62683289, -1.11799203, -1.13581511],
287        [ 0.97068626, -0.42845732,  0.52818096,  0.18930252],
288        [ 1.09452838,  0.48500499,  1.1525914 ,  1.77944368],
289        [-1.25847174, -0.20009175, -1.28828579, -1.40083864],
290        [-1.01078752,  1.17010173, -1.28828579, -1.26832688],
291        [ 0.2276336 , -0.20009175,  0.64171013,  0.85186133],
292        [-1.01078752, -0.20009175, -1.17475662, -1.26832688],
293        [ 0.35147571, -0.20009175,  0.69847471,  0.85186133],
294        [ 0.72300204,  0.02827383,  1.03906223,  0.85186133],
295        [-0.88694541,  1.39846731, -1.2315212 , -1.00330335],
296        [-0.14389274, -0.20009175,  0.30112261,  0.05679076],
297        [-1.01078752,  0.94173615, -1.34505037, -1.13581511],
298        [-0.88694541,  1.62683289, -1.2315212 , -1.13581511],
299        [-1.50615597,  0.25663941, -1.28828579, -1.26832688],
300        [-0.51541907, -0.20009175,  0.47141637,  0.45432605],
301        [ 0.84684415, -0.6568229 ,  0.52818096,  0.45432605],
302        [ 0.35147571, -0.6568229 ,  0.18759344,  0.18930252],
303        [-1.25847174,  0.71337057, -1.17475662, -1.26832688],
304        [-0.88694541,  0.48500499, -1.11799203, -0.87079159],
305        [-0.02005063, -0.88518848,  0.81200388,  0.9843731 ],
306        [-0.26773485, -0.20009175,  0.24435803,  0.18930252],
307        [ 0.59915993, -0.6568229 ,  0.81200388,  0.45432605],
308        [ 1.09452838,  0.48500499,  1.1525914 ,  1.24939662],
309        [ 1.71373893, -0.20009175,  1.20935598,  0.58683781],
310        [ 1.09452838, -0.20009175,  0.86876847,  1.51442015],
311        [-1.13462963,  0.02827383, -1.2315212 , -1.40083864],
312        [-1.13462963, -1.34191964,  0.47141637,  0.71934957],
313        [-0.14389274, -1.34191964,  0.7552393 ,  1.11688486],
314        [-1.13462963, -1.57028522, -0.20975866, -0.20823277],
315        [-0.39157696, -1.57028522,  0.07406427, -0.07572101],
316        [ 1.09452838, -1.34191964,  1.20935598,  0.85186133],
317        [ 0.84684415, -0.20009175,  1.03906223,  0.85186133],
318        [-0.14389274, -1.11355406, -0.09622949, -0.20823277],
319        [ 0.2276336 , -2.02701638,  0.7552393 ,  0.45432605],
320        [ 1.09452838,  0.02827383,  0.58494554,  0.45432605],
321        [-1.13462963,  0.02827383, -1.2315212 , -1.26832688],
322        [ 0.59915993, -1.34191964,  0.7552393 ,  0.9843731 ],
323        [-1.38231385,  0.25663941, -1.34505037, -1.26832688],
324        [ 0.2276336 , -0.88518848,  0.81200388,  0.58683781],
325        [-0.02005063, -1.11355406,  0.18759344,  0.05679076],
326        [ 1.3422126 ,  0.25663941,  1.1525914 ,  1.51442015],
327        [-1.75384019, -0.20009175, -1.34505037, -1.26832688],
328        [ 1.58989682, -0.20009175,  1.26612057,  1.24939662],
329        [ 1.21837049,  0.25663941,  1.26612057,  1.51442015],
330        [-0.7631033 ,  0.94173615, -1.2315212 , -1.26832688],
331        [ 2.58063371,  1.62683289,  1.5499435 ,  1.11688486],
332        [ 0.72300204, -0.6568229 ,  1.09582681,  1.38190839],
333        [-0.26773485, -0.42845732, -0.0394649 ,  0.18930252],
334        [-0.39157696,  2.5402952 , -1.28828579, -1.26832688],
335        [-1.25847174, -0.20009175, -1.28828579, -1.13581511],
336        [ 0.59915993, -0.42845732,  1.09582681,  0.85186133],
337        [-1.75384019,  0.25663941, -1.34505037, -1.26832688],
338        [-0.51541907,  1.85519847, -1.11799203, -1.00330335],
339        [-1.01078752,  0.71337057, -1.17475662, -1.00330335],
340        [ 1.09452838, -0.20009175,  0.7552393 ,  0.71934957],
341        [-0.51541907,  1.85519847, -1.34505037, -1.00330335],
342        [ 2.33294949, -0.6568229 ,  1.72023726,  1.11688486],
343        [-0.26773485, -0.88518848,  0.30112261,  0.18930252],
344        [ 1.21837049, -0.20009175,  1.03906223,  1.24939662],
345        [-0.39157696,  0.94173615, -1.34505037, -1.26832688],
346        [-1.25847174,  0.71337057, -1.00446286, -1.26832688],
347        [-0.51541907,  0.71337057, -1.11799203, -1.26832688],
348        [ 2.33294949,  1.62683289,  1.72023726,  1.38190839],
349        [ 1.3422126 ,  0.02827383,  0.98229764,  1.24939662],
350        [-0.26773485, -1.34191964,  0.13082885, -0.07572101],
351        [-0.88694541,  0.71337057, -1.2315212 , -1.26832688],
352        [-0.88694541,  1.62683289, -1.17475662, -1.26832688],
353        [ 0.35147571, -0.42845732,  0.58494554,  0.32181428],
354        [-0.02005063,  2.08356405, -1.40181496, -1.26832688],
355        [-1.01078752, -2.48374754, -0.09622949, -0.20823277],
356        [ 0.72300204,  0.25663941,  0.47141637,  0.45432605],
357        [ 0.35147571, -0.20009175,  0.52818096,  0.32181428],
358        [ 0.10379148,  0.25663941,  0.64171013,  0.85186133],
359        [ 0.2276336 , -2.02701638,  0.18759344, -0.20823277],
360        [ 1.96142316, -0.6568229 ,  1.37964974,  0.9843731 ]])
361 [15]
362 X_test_standard = standardScaler.transform(X_test)
363 [16]
364 X_test_standard
365 array([[-0.26773485, -0.20009175,  0.47141637,  0.45432605],
366        [-0.02005063, -0.6568229 ,  0.81200388,  1.64693191],
367        [-1.01078752, -1.7986508 , -0.20975866, -0.20823277],
368        [-0.02005063, -0.88518848,  0.81200388,  0.9843731 ],
369        [-1.50615597,  0.02827383, -1.2315212 , -1.26832688],
370        [-0.39157696, -1.34191964,  0.18759344,  0.18930252],
371        [-0.14389274, -0.6568229 ,  0.47141637,  0.18930252],
372        [ 0.84684415, -0.20009175,  0.86876847,  1.11688486],
373        [ 0.59915993, -1.7986508 ,  0.41465178,  0.18930252],
374        [-0.39157696, -1.11355406,  0.41465178,  0.05679076],
375        [ 1.09452838,  0.02827383,  0.41465178,  0.32181428],
376        [-1.62999808, -1.7986508 , -1.34505037, -1.13581511],
377        [-1.25847174,  0.02827383, -1.17475662, -1.26832688],
378        [-0.51541907,  0.71337057, -1.2315212 , -1.00330335],
379        [ 1.71373893,  1.17010173,  1.37964974,  1.77944368],
380        [-0.02005063, -0.88518848,  0.24435803, -0.20823277],
381        [-1.50615597,  1.17010173, -1.51534413, -1.26832688],
382        [ 1.71373893,  0.25663941,  1.32288516,  0.85186133],
383        [ 1.3422126 ,  0.02827383,  0.81200388,  1.51442015],
384        [ 0.72300204, -0.88518848,  0.92553306,  0.9843731 ],
385        [ 0.59915993,  0.48500499,  0.58494554,  0.58683781],
386        [-1.01078752,  0.71337057, -1.2315212 , -1.26832688],
387        [ 2.33294949, -1.11355406,  1.83376643,  1.51442015],
388        [-1.01078752,  0.48500499, -1.28828579, -1.26832688],
389        [ 0.47531782, -0.42845732,  0.3578872 ,  0.18930252],
390        [ 0.10379148, -0.20009175,  0.30112261,  0.45432605],
391        [-1.01078752,  0.25663941, -1.40181496, -1.26832688],
392        [-0.39157696, -1.7986508 ,  0.18759344,  0.18930252],
393        [ 0.59915993,  0.48500499,  1.32288516,  1.77944368],
394        [ 2.33294949, -0.20009175,  1.37964974,  1.51442015],
395        [-0.88694541,  0.94173615, -1.28828579, -1.26832688],
396        [-1.13462963, -0.20009175, -1.28828579, -1.26832688],
397        [-0.14389274, -0.6568229 ,  0.24435803,  0.18930252],
398        [ 0.47531782,  0.71337057,  0.98229764,  1.51442015],
399        [-0.88694541, -1.34191964, -0.38005242, -0.07572101],
400        [ 1.46605471,  0.25663941,  0.58494554,  0.32181428],
401        [ 0.35147571, -1.11355406,  1.09582681,  0.32181428],
402        [ 2.20910738, -0.20009175,  1.66347267,  1.24939662],
403        [-0.7631033 ,  2.31192962, -1.2315212 , -1.40083864],
404        [ 0.47531782, -2.02701638,  0.47141637,  0.45432605],
405        [ 1.83758104, -0.42845732,  1.49317891,  0.85186133],
406        [ 0.72300204,  0.25663941,  0.92553306,  1.51442015],
407        [ 0.2276336 ,  0.71337057,  0.47141637,  0.58683781],
408        [-0.7631033 , -0.88518848,  0.13082885,  0.32181428],
409        [-0.51541907,  1.39846731, -1.2315212 , -1.26832688]])
410 [17]
411 from sklearn.neighbors import KNeighborsClassifier
412 [18]
413 knn_clf = KNeighborsClassifier(n_neighbors=3)
414 [19]
415 knn_clf.fit(X_train,y_train)
416 KNeighborsClassifier(n_neighbors=3)
417 [20]
418 knn_clf.score(X_test_standard,y_test)
419 0.9777777777777777
420 [21]
421 knn_clf.score(X_test,y_test)
422 0.3333333333333333
复制代码

 

4-9 更多有关k近邻算法的思考

 

 

 

 

 

 

 

 

 

 

 

posted @   Cai-Gbro  阅读(232)  评论(0)  编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 三行代码完成国际化适配,妙~啊~
· .NET Core 中如何实现缓存的预热?
点击右上角即可分享
微信分享提示