2/8 iris_data_analysis
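This afternoon I watched part of the relational database course on Lynda. The instructor is enthusiastic but pads things out too much, so I dropped it halfway. I'd better find some tutorial PDFs to read instead orz.
In the evening I put together a simple Jupyter notebook analyzing the iris dataset, using K Nearest Neighbors as the model.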
from sklearn.datasets import load_iris
# sklearn.cross_validation was removed in scikit-learn 0.20; use model_selection instead
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics
import matplotlib.pyplot as plt

iris = load_iris()
X = iris.data[:, :4]  # X = features (all four columns)
y = iris.target       # y = labels
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=4)

# Fit a single KNN model with K=5 and check accuracy on the held-out test set
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print(metrics.accuracy_score(y_test, y_pred))

# Try K = 1..30 and record the mean 10-fold cross-validation accuracy for each
k_range = list(range(1, 31))
k_scores = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')
    k_scores.append(scores.mean())
print(k_scores)

# Plot mean cross-validation accuracy against K
plt.plot(k_range, k_scores)
plt.xlabel('Value K for KNN')
plt.ylabel('Cross_val_score')
plt.show()
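The final plot looks like this:
It shows that K=20 reaches the highest cross-validation score, 0.98, and is also the simplest of the top-scoring models. (The larger K is, the simpler the model.)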
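Instead of reading the best K off the plot, the same choice can be made in code. This is only a rough sketch under my own assumptions: the names `best_score`, `best_ks` and `chosen_k` are made up for illustration, and "largest K among the top scorers" is just the tie-breaking rule described above, not anything built into scikit-learn. The exact scores may also differ between scikit-learn versions, so no expected output is given.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# Mean 10-fold cross-validation accuracy for K = 1..30
k_range = list(range(1, 31))
k_scores = [
    cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y,
                    cv=10, scoring='accuracy').mean()
    for k in k_range
]

# Hypothetical selection rule: collect every K tied for the best score,
# then take the largest one, since a larger K gives a smoother (simpler) model.
best_score = max(k_scores)
best_ks = [k for k, s in zip(k_range, k_scores) if np.isclose(s, best_score)]
chosen_k = max(best_ks)

print(best_score, best_ks, chosen_k)
```

Using `np.isclose` instead of `==` avoids missing a tie because of floating-point rounding in the averaged scores.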