sklearn中数据缩放用到的fit_transform()、transform()、fit()方法的区别与联系
看了一堆搜索排名靠前的中文博客,感觉没有一个解释能让人醍醐灌顶的,故搜索英文网页并记之。
谢绝转载。
首先对于数据标准化一般是这么做的:
其中μ是均值, σ是标准差。目的是使数据服从均值为零,标准差为1的标准正态分布,此即标准化(Standardization)。
标准化都是给训练集数据做的,但在以下情况中也必须做数据标准化,比如,交叉验证时的测试集,或者是预测前获得了一组新的样本。而在对新的数据或测试集进行标准化时,我们所用的是训练集标准化中的均值μ和标准差σ。
因此,StandardScaler 中的fit()所做的就是计算数据的均值μ和标准差σ,并将他们储存为一个内部对象的状态,无返回值。然后,对测试集调用transform()方法,此方法将使用刚刚fit()计算得到的均值μ和标准差σ来对测试集数据进行标准化。
而fit_transform()就是将以上两步二合一,因为其内部就是先后调用fit()和transform()函数的。
所以我们经常能看到类似这样的代码:
1 # Feature Scaling 2 from sklearn.preprocessing import StandardScaler 3 sc = StandardScaler() 4 X_train = sc.fit_transform(X_train) 5 X_test = sc.transform(X_test)
注意这里fit_transform()是用在训练集上的,也就是说,fit_transform()先计算了训练集数据的均值μ和标准差σ,并以此对训练集进行标准化。
参考:
https://datascience.stackexchange.com/questions/12321/whats-the-difference-between-fit-and-fit-transform-in-scikit-learn-models
https://www.kaggle.com/questions-and-answers/58368