使用 sklearn 的 SGD 与随机森林(Random Forest)对手写数字进行多分类识别

import numpy as np
from sklearn.datasets import fetch_openml

# Fetch the MNIST digits from OpenML as arrays (as_frame=False) rather than
# as a DataFrame.
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='pandas')
X = mnist["data"]
# The labels are strings; convert them to integers.
y = mnist["target"].astype(np.uint8)

# 70,000 images in total: the first 60,000 are used for training and the
# remaining 10,000 are held out for validation.
X_train, y_train = X[:60000], y[:60000]
X_test, y_test = X[60000:], y[60000:]
# Train the models.
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import SGDClassifier

# Both classifiers are fitted on the same training split and are seeded with
# random_state=42. fit() returns the estimator itself, so construction and
# training are chained.

# Stochastic gradient descent (SGD) classifier.
sgd_clf = SGDClassifier(random_state=42).fit(X_train, y_train)

# Random forest with 100 trees.
forest_clf = RandomForestClassifier(
    random_state=42,
    n_estimators=100,
).fit(X_train, y_train)
# Sanity-check both classifiers on the first two samples. Rows 0 and 1 fall
# inside the training split (X[:60000]), so this demonstrates the API rather
# than measuring generalisation.
#
# Every result is printed explicitly: a bare expression is only echoed when it
# is the last statement of a notebook cell, and is silently discarded when the
# code runs as a plain script.

# Random forest: predicted labels next to the true labels.
print(y[0], y[1])
print(forest_clf.predict([X[0], X[1]]))
# Classification scores: per-class probability estimates for the first sample.
print(forest_clf.predict_proba([X[0]]))

# SGD classifier: predicted labels next to the true labels.
print(y[0], y[1])
print(sgd_clf.predict([X[0], X[1]]))
# Classification scores: per-class decision-function values for the first sample.
print(sgd_clf.decision_function([X[0]]))
# Enable inline figures when running under IPython/Jupyter. The raw
# `%matplotlib inline` magic is not valid Python syntax, so it is invoked
# through the IPython API, which is what the magic expands to.
try:
    get_ipython().run_line_magic("matplotlib", "inline")  # noqa: F821
except NameError:
    # Plain Python interpreter: get_ipython is not defined. Skip the magic and
    # let plt.show() below display the figure with the default backend.
    pass
import matplotlib as mpl
import matplotlib.pyplot as plt

# Reshape one flat sample into a 28x28 grid so it can be drawn as an image.
some_digit_image = X[2028].reshape(28, 28)
plt.imshow(some_digit_image, cmap='binary')
plt.axis("off")

plt.show()
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict
from sklearn.model_selection import cross_val_score

# Model score: accuracy of the SGD classifier under 3-fold cross-validation on
# the training split. The scores are printed explicitly because a bare
# expression is discarded when the code runs as a plain script.
print(cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy"))

# Confusion matrix for the random forest, built from 3-fold cross-validated
# predictions on the training split.
y_train_pred = cross_val_predict(forest_clf, X_train, y_train, cv=3)

# Arguments are (true labels, predicted labels).
conf_mx = confusion_matrix(y_train, y_train_pred)
print(conf_mx)

# Show the matrix as a grayscale image.
plt.matshow(conf_mx, cmap='gray')
plt.show()

 

posted @ 2023-03-15 16:10  缘故为何  阅读(82)  评论(0)  编辑  收藏  举报