Sklearn保存和导入模型,pickle.dump()
Python中Pickle模块的dump()方法和load()方法
在机器学习中,我们常常需要把训练好的模型存储起来,这样在进行决策时直接将模型读出,而不需要重新训练模型,这样就大大节约了时间。Python提供的pickle模块就很好地解决了这个问题,它可以序列化对象并保存到磁盘中,并在需要的时候读取出来,任何对象都可以执行序列化操作。
Pickle模块中最常用的函数为:
(1)pickle.dump(obj, file, [,protocol])
函数的功能:将obj对象序列化存入已经打开的file中。
参数讲解:
obj:想要序列化的obj对象。
file:文件名称。
protocol:序列化使用的协议。如果该项省略,则默认为0。如果为负值或HIGHEST_PROTOCOL,则使用最高的协议版本。
(2)pickle.load(file)
函数的功能:将file中的对象序列化读出。
参数讲解:
file:文件名称。
(3)pickle.dumps(obj[, protocol])
函数的功能:将obj对象序列化为string形式,而不是存入文件中。
参数讲解:
obj:想要序列化的obj对象。
protocal:如果该项省略,则默认为0。如果为负值或HIGHEST_PROTOCOL,则使用最高的协议版本。
(4)pickle.loads(string)
函数的功能:从string中读出序列化前的obj对象。
参数讲解:
string:文件名称。
【注】 dump() 与 load() 相比 dumps() 和 loads() 还有另一种能力:dump()函数能一个接着一个地将几个对象序列化存储到同一个文件中,随后调用load()来以同样的顺序反序列化读出这些对象。
import matplotlib.pyplot as plt import pandas as pd from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler from sklearn.datasets import load_iris data=load_iris() print(data.feature_names,data.target_names) df=pd.DataFrame(data.data,columns=data.feature_names) df['class']=data.target print("************"*10) print(type(data),sep='\n') X,y=load_iris().data,load_iris().target data_pca=PCA(n_components=2).fit_transform(X) print(data_pca.shape) print(y.shape) print('y: ',y) # z=[z for z in range(3) for i in range(150)] plt.scatter(data_pca[:,0],data_pca[:,1],c=y)
#输出: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'] ['setosa' 'versicolor' 'virginica'] ************************************************************************************************************************ <class 'sklearn.utils._bunch.Bunch'> (150, 2) (150,) y: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
from sklearn.datasets import load_iris X, y = load_iris().data, load_iris().target #或 from sklearn.datasets import load_boston boston = load_boston() x = boston.data y = boston.target #下载数据集 from sklearn.datasets import load_boston boston = load_boston() #查看数据键值 boston.keys() #dict_keys(['data', 'target', 'feature_names', 'DESCR', 'filename']) x = boston['data'] y = boston['target'] #查看数据结构 x.shape #(506, 13) #506个样本,13个属性
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2) #选择模型拟合训练集 from sklearn.linear_model import LinearRegression lr = LinearRegression() lr.fit(x_train,y_train) #使用测试集测试数据 y_ = lr.predict(x_test) #预测值与真实值的差距 (y_test-y_).round(2)
from sklearn import datasets import numpy as np import pandas as pd iris = datasets.load_iris() X=iris.data[:,:] y=iris.target print(X) print(X.shape,type(iris)) print(y) from sklearn.model_selection import train_test_split X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.33,random_state=3,shuffle=False)#默认shuffle为True,此处如 # 果用True效果比用False好,自己实验的一般来说随机种子用3比较好 print(X_train,X_test,y_train,y_test,sep='\n') from sklearn.preprocessing import StandardScaler sc=StandardScaler() # sc.fit(X_train) sc.fit(X) X_train_std=sc.transform(X_train) X_test_std=sc.transform(X_test) print(X_train.shape,y_train.shape,X_test.shape,y_test.shape,sep='\n') #拟合/训练 from sklearn.linear_model import Perceptron ppn=Perceptron(n_iter_no_change=50,eta0=0.1,random_state=3) ppn.fit(X_train_std,y_train) y_pred=ppn.predict(X_test_std)#用导入的模型重新预测 print(f"Misclassfied sampled: {(y_test != y_pred).sum()}") # 法1. 保存和导入训练好的模型 import pickle f=pickle.dumps(ppn)#保存模型 ppn1=pickle.loads(f)#导入模型 y_pred1=ppn1.predict(X_test_std)#用导入的模型重新预测 print(f"Misclassfied sampled: {(y_test != y_pred1).sum()}") # 法2. 保存和导入训练好的模型 from joblib import dump,load dump(ppn,'./ppn.joblib')#当然也可以保存成其他后缀,比如pth ppn2=load('ppn.joblib')# y_pred2=ppn2.predict(X_test_std) print(f"Misclassfied sampled: {(y_test != y_pred2).sum()}") # ## sklearn 示例 # from sklearn import svm # from sklearn import datasets # clf = svm.SVC() # X, y= datasets.load_iris(return_X_y=True) # clf.fit(X, y) # # import pickle # s = pickle.dumps(clf) # clf2 = pickle.loads(s) # clf2.predict(X[0:1]) # print(y[0]) # # from joblib import dump, load # dump(clf, 'filename.joblib') # clf = load('filename.joblib') # print(clf)
#打印出准确率 from sklearn.metrics import accuracy_score print(f"Accuracy: {accuracy_score(y_test,y_pred):.2f}")
pickle.dump()
封装是一个将Python数据对象转化为字节流的过程,拆封是封装的逆操作,将字节文件或字节对象中的字节流转化为Python数据对象,不要从不收信任的数据源中拆封数据。可以封装和拆封几乎任何Python数据对象,主要包括:
None , True,False
整数,浮点数,复数
字符串,字节,ByteArray对象
元组,列表,集合,包含可封装对象的字典
在一个模块的顶层定义的函数
在一个模块的顶层定义的内置函数
那是在一个模块的顶层定义的类
__dict__或调用__getstate__()的结果是可封装的类的实例
pickle模块中常用的方法有:
1. pickle.dump(obj, file, protocol=None,)
必填参数obj表示将要封装的对象
必填参数file表示obj要写入的文件对象,file必须以二进制可写模式打开,即“wb”
可选参数protocol表示告知pickler使用的协议,支持的协议有0,1,2,3,默认的协议是添加在Python 3中的协议3, 其他的协议详情见参考文档
2. pickle.load(file,*,fix_imports=True, encoding="ASCII", errors="strict")
必填参数file必须以二进制可读模式打开,即“rb”,其他都为可选参数
3. pickle.dumps(obj):以字节对象形式返回封装的对象,不需要写入文件中
4. pickle.loads(bytes_object): 从字节对象中读取被封装的对象,并返回
pickle模块可能出现三种异常:
1. PickleError:封装和拆封时出现的异常类,继承自Exception
2. PicklingError: 遇到不可封装的对象时出现的异常,继承自PickleError
3. UnPicklingError: 拆封对象过程中出现的异常,继承自PickleError
pickle应用实例:
import pickle with open("my_profile.txt", "wb") as myprofile: pickle.dump({"name":"AlwaysJane", "age":"20+", "sex":"female"}, myprofile) with open("my_profile.txt", "rb") as get_myprofile: print (pickle.load(get_myprofile))
import pickle class Profile: name = "AlwaysJane" pickledclass = pickle.dumps(Profile) print (pickledclass) print (pickle.loads(pickledclass))
基础知识:
python自带的file函数只能存储和读取字符串格式的数据.
pickle可以存储和读取成其他格式比如list dict的数据,
pickle的功能就是把你上次计算得到的数据保存起来,当你需要使用这些数据时,直接通过reload把数据恢复了就行,这样的好处有:
1.被pickle的数据,在被多次reload时,不需要重新去计算得到这些数据,这样节省计算机资源,如果你不pickle,你每调用一次数据,就要计算一次。
2通过pickle的数据,被reload时,可以更好的被内存调用,不需要经过数据格式的转换。
有人可能觉得,我直接通过open把数据写到一个txt文档也能达到以上的效果,但是这样做的结果是,你能够达到pickle的功能,把数据保存起来,但是当你再去调用这些数据时,你的txt格式的数据,没有pickle的数据读取更高效。
另外还有一点,你通过open把数据存储到txt中时的效率,就不如pickle的效率高。
综上,你如果只是做一次的数据存储和调用,以及数据量很小的情况下,你可以用open等方法保存数据和调用数据,但是当你需要通过大量计算得到一个数据,同时后期还会多次使用这个数据时,pickle的节省计算机资源的效果就出来了。