引入包(注意joblib的引入,如果使用from sklearn.externals import joblib会报错“ImportError: cannot import name 'joblib”)
from sklearn.datasets import load_iris from sklearn.linear_model import LogisticRegression import joblib
模型的保存
def model_save(): #X的shape是(150,4),y是个一维数组,长度为150,可能有3种标签 X, y = load_iris(return_X_y=True) #训练模型 clf = LogisticRegression(random_state=0).fit(X, y) joblib.dump(clf, "lr.pkl")
模型的加载和使用
def model_load(): clf = joblib.load("lr.pkl") #X的shape是(150,4),y是个一维数组,长度为150,可能有3种标签 X, y = load_iris(return_X_y=True) #处理前两行,所有列。predict,返回的是这两行数据的预测结果,shape为(2,) labels = clf.predict(X[:2, :]) print(labels.shape)