使用 sklearn 的 SGD 与随机森林(Random Forest)二分类器识别手写数字 5

import numpy as np
from sklearn.datasets import fetch_openml

# Fetch the MNIST dataset as plain arrays (as_frame=False) rather than a
# DataFrame.
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='pandas')
X = mnist["data"]
y = mnist["target"]
print(y[0])

# The labels arrive as strings; convert them to small unsigned integers.
y = y.astype(np.uint8)

# The dataset holds 70,000 images in total: the first 60,000 are used for
# training and the remaining 10,000 for testing.
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]

# Binary targets: True where the digit is a 5, False for every other digit.
y_train_5 = y_train == 5
y_test_5 = y_test == 5
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import SGDClassifier

# Train a linear classifier with stochastic gradient descent on the binary
# "is this digit a 5?" targets. The fixed random_state keeps runs repeatable.
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

# Train a random forest of 100 trees on the same binary targets.
forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)
forest_clf.fit(X_train, y_train_5)

# Classify one sample with both models and show the true digit label next to
# the predictions. The values are printed explicitly so they are visible when
# this file runs as a script (a bare expression shows nothing outside a REPL).
# NOTE: index 2028 falls inside the first 60,000 samples, i.e. the training
# split, so this is only a smoke check and not a measure of accuracy.
print("random forest prediction:", forest_clf.predict([X[2028]]))
print("SGD prediction:", sgd_clf.predict([X[2028]]))
print("true label:", y[2028])
import pickle

# Persist the trained SGD classifier to disk.
with open('image_5.pickle', 'wb') as model_out:
    pickle.dump(sgd_clf, model_out)

# Restore the classifier from the file and check that it can still predict.
# NOTE: unpickling can execute arbitrary code; only load pickle files that
# come from a trusted source.
with open('image_5.pickle', 'rb') as model_in:
    new_svm = pickle.load(model_in)
print(new_svm.predict([X[2028]]))
from sklearn.metrics import (
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import cross_val_predict, cross_val_score

# Every result below is printed explicitly so it is visible when this file
# runs as a script (a bare expression shows nothing outside a REPL).

# Accuracy of the SGD classifier under 3-fold cross-validation.
print(
    "cross-validation accuracy:",
    cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy"),
)

# Cross-validated predictions for the training set, obtained with the same
# 3-fold setup; they feed all of the metrics below.
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)

# Confusion matrix of true labels versus cross-validated predictions.
print("confusion matrix:")
print(confusion_matrix(y_train_5, y_train_pred))

# Precision: share of samples predicted as "5" that really are a 5.
print("precision:", precision_score(y_train_5, y_train_pred))

# Recall: share of the real 5s that the classifier detected.
print("recall:", recall_score(y_train_5, y_train_pred))

# F1 score: harmonic mean of precision and recall.
print("f1:", f1_score(y_train_5, y_train_pred))

 

posted @ 2023-03-14 15:22  缘故为何  阅读(44)  评论(0)  编辑  收藏  举报