sklearn：使用 SGD 与随机森林（Random Forest）对 MNIST 手写数字进行多分类识别
# Download the MNIST digits ("mnist_784": 784 = 28 x 28 pixel features per image)
# from OpenML; as_frame=False asks for plain NumPy arrays instead of a DataFrame.
from sklearn.datasets import fetch_openml

mnist = fetch_openml(
    'mnist_784',
    version=1,
    as_frame=False,
    parser='pandas',
)
# Pixel matrix and label vector (labels arrive as strings; cast in the next cell).
X, y = mnist["data"], mnist["target"]
import numpy as np

# Labels come back as strings; convert them to small unsigned integers.
y = y.astype(np.uint8)

# 70,000 images in total: the first 60,000 are used for training and
# the last 10,000 are held out for evaluation.
X_train, y_train = X[:60000], y[:60000]
X_test, y_test = X[60000:], y[60000:]
# Train a linear classifier optimised with stochastic gradient descent (SGD).
from sklearn.linear_model import SGDClassifier

# Fixed seed so the stochastic training is reproducible between runs.
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train)
# Train a random forest (an ensemble of 100 decision trees) on the same data.
from sklearn.ensemble import RandomForestClassifier

forest_clf = RandomForestClassifier(
    n_estimators=100,
    random_state=42,  # fixed seed for reproducible tree construction
)
forest_clf.fit(X_train, y_train)
# Compare the true labels of the first two samples with the forest's predictions.
# NOTE: these samples are part of X_train, so this is a sanity check, not a held-out test.
print(*y[:2])
forest_clf.predict(X[:2])
# Per-class probability estimates of the random forest for the first sample
# (a one-row 2-D slice, as predict_proba expects a batch).
forest_clf.predict_proba(X[:1])
# Compare the true labels of the first two samples with the SGD model's predictions.
# NOTE: these samples are part of X_train, so this is a sanity check, not a held-out test.
print(*y[:2])
sgd_clf.predict(X[:2])
# Raw decision scores of the SGD model for the first sample: one score per
# class; the predicted class is the one with the highest score.
sgd_clf.decision_function(X[:1])
%matplotlib inline import matplotlib as mpl import matplotlib.pyplot as plt some_digit_image = X[2028].reshape(28, 28) plt.imshow(some_digit_image,cmap='binary') plt.axis("off") plt.show()
# 3-fold cross-validated accuracy of the SGD classifier on the training set.
# cross_val_score fits clones of the estimator, so the fitted sgd_clf is left untouched.
from sklearn.model_selection import cross_val_score

cross_val_score(
    estimator=sgd_clf,
    X=X_train,
    y=y_train,
    cv=3,
    scoring="accuracy",
)
# Confusion matrix of the random forest built from out-of-fold predictions.
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix

# With cv=3, each training sample is predicted by a model fitted on the other
# two folds, so the matrix reflects errors on data that model did not see.
y_train_pred = cross_val_predict(forest_clf, X_train, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
# Display the confusion matrix. Following sklearn's convention, entry [i, j]
# counts samples whose true class is the i-th label and whose predicted class
# is the j-th label, so correct predictions lie on the diagonal.
conf_mx
# Render the confusion matrix as an image. With the 'gray' colormap, larger
# counts are brighter, so a bright diagonal means mostly correct predictions.
plt.matshow(
    conf_mx,
    cmap='gray',
)
plt.show()