sklearn使用SGD/Random Forest二分类识别手写数字5
from sklearn.datasets import fetch_openml

# Fetch version 1 of the "mnist_784" dataset from OpenML. as_frame=False asks
# for array output instead of a pandas DataFrame.
mnist = fetch_openml(
    'mnist_784',
    version=1,
    as_frame=False,
    parser='pandas',
)

# Pull the feature matrix and the label vector out of the returned bunch.
X = mnist["data"]
y = mnist["target"]
import numpy as np

# Show the first label exactly as loaded, before any conversion.
print(y[0])

# Labels arrive as strings; convert them to unsigned 8-bit integers so they
# can be compared with the integer 5 below.
y = y.astype(np.uint8)

# The first 60000 samples are used for training; everything from index 60000
# onward is held out as the test set (the original note puts the dataset at
# 70,000 images in total, i.e. 10,000 for testing -- TODO confirm).
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]

# Binary targets: True where the digit is 5, False for every other digit.
y_train_5 = y_train == 5
y_test_5 = y_test == 5
from sklearn.linear_model import SGDClassifier

# Train the "is this a 5?" detector on the binary training targets.
# fit() returns the estimator itself, so construction and training are chained;
# the fixed random_state keeps runs repeatable.
sgd_clf = SGDClassifier(random_state=42).fit(X_train, y_train_5)
from sklearn.ensemble import RandomForestClassifier

# Random forest with 100 trees, trained on the same binary targets as the
# SGD model. fit() returns the estimator, so the calls are chained.
forest_clf = RandomForestClassifier(
    n_estimators=100,
    random_state=42,
).fit(X_train, y_train_5)

# Classify a single handwritten sample. predict() takes a batch of samples,
# hence the one-element list around X[2028].
forest_clf.predict([X[2028]])
# Classify the handwritten sample at index 2028 with the SGD model.
# predict() takes a batch of samples, hence the one-element list.
sgd_clf.predict(X=[X[2028]])
# Label stored for sample 2028, to compare against the model predictions.
y[2028]
import pickle

# Persist the trained SGD model to disk. The file is opened in binary mode
# because pickle writes bytes.
with open('image_5.pickle', 'wb') as model_file:
    pickle.dump(sgd_clf, model_file)
# Restore the model saved in 'image_5.pickle'.
# NOTE: pickle.load can execute arbitrary code while deserializing; only load
# pickle files that come from a trusted source.
with open('image_5.pickle', 'rb') as model_file:
    new_svm = pickle.load(model_file)

# Sanity check: the restored model classifies the same sample as before.
print(new_svm.predict([X[2028]]))
from sklearn.model_selection import cross_val_score

# Score the SGD model with 3-fold cross-validation on the training set,
# reporting accuracy for each fold.
# NOTE(review): the positive class is only the digit-5 samples, so accuracy
# may be dominated by the negative class -- read it together with the
# precision/recall figures.
cross_val_score(
    estimator=sgd_clf,
    X=X_train,
    y=y_train_5,
    cv=3,
    scoring="accuracy",
)
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict

# Cross-validated predictions for every training sample (3 folds), so each
# prediction comes from a model that was not fitted on that sample.
y_train_pred = cross_val_predict(
    estimator=sgd_clf,
    X=X_train,
    y=y_train_5,
    cv=3,
)

# Confusion matrix of true binary labels versus the cross-validated predictions.
confusion_matrix(y_true=y_train_5, y_pred=y_train_pred)
from sklearn.metrics import precision_score, recall_score

# Precision of the cross-validated predictions for the "is a 5" class.
precision_score(y_true=y_train_5, y_pred=y_train_pred)

# Recall of the cross-validated predictions for the "is a 5" class.
recall_score(y_true=y_train_5, y_pred=y_train_pred)
from sklearn.metrics import f1_score

# F1 score of the cross-validated predictions: a single figure combining
# precision and recall.
f1_score(y_true=y_train_5, y_pred=y_train_pred)