Recently, while running machine-learning comparison experiments with sklearn, I noticed that for many algorithms the training time grows roughly geometrically as the dataset gets larger; add cross-validation and parameter search on top of that and the experiments become very time-consuming.
Many optimizations have been proposed for this. Here I recommend a fast implementation of the K-NN classifier: a faiss-based version. In my own tests the speed-up is substantial.
(1) Install the GPU version of faiss (https://pypi.org/project/faiss-gpu/)
pip install faiss-gpu
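After installation, a quick sanity check can confirm that faiss actually sees the GPU. A minimal sketch (in a CPU-only build, get_num_gpus simply returns 0):

import faiss

# Number of CUDA devices visible to faiss; 0 means searches will run on the CPU
print(faiss.get_num_gpus())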
(2) Install DESlib, the package that provides the faiss-backed KNN wrapper
DESlib is an ensemble-learning library similar to sklearn. It offers an interface that is largely consistent with sklearn's and focuses on implementations of state-of-the-art dynamic classifier and ensemble selection techniques.
pip install deslib
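As a side note, DESlib's own estimators follow the familiar sklearn fit/predict pattern. A minimal sketch of the dynamic-selection usage the library is built around (KNORAE and the bagging pool here are illustrative choices, unrelated to the faiss wrapper):

from sklearn.datasets import make_classification
from sklearn.ensemble import BaggingClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from deslib.des import KNORAE

X, y = make_classification(n_samples=2000, random_state=42)
X_train, X_dsel, y_train, y_dsel = train_test_split(X, y, random_state=42)

# Pool of base classifiers trained on one split
pool = BaggingClassifier(DecisionTreeClassifier(), n_estimators=10, random_state=42)
pool.fit(X_train, y_train)

# Dynamic ensemble selection fitted on a separate selection (DSEL) split
des = KNORAE(pool)
des.fit(X_dsel, y_dsel)
print(des.score(X_dsel, y_dsel))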
(3) Usage example; see the API documentation: https://deslib.readthedocs.io/en/latest/modules/util/faiss_knn_wrapper.html
from deslib.util.faiss_knn_wrapper import FaissKNNClassifier

clf = FaissKNNClassifier(n_neighbors=5,
                         n_jobs=10,
                         algorithm='brute',
                         n_cells=100,
                         n_probes=2)
clf.fit(X_train, y_train)                  # train
y_test_proba = clf.predict_proba(X_test)   # predicted class probabilities
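Continuing the snippet above, hard labels can also be obtained directly; a short sketch (accuracy_score and the y_test labels are assumptions on top of the documented example):

from sklearn.metrics import accuracy_score

y_pred = clf.predict(X_test)             # hard label predictions
print(accuracy_score(y_test, y_pred))    # accuracy on the held-out labels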
In my own tests, it trains more than 100 times faster than the KNN implementation in sklearn:
satellite dataset: sklearn 1867s vs faiss 8s
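A minimal sketch of how such a timing comparison can be reproduced (the synthetic dataset below is an illustrative stand-in; the satellite numbers above are from my own runs):

import time

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from deslib.util.faiss_knn_wrapper import FaissKNNClassifier

X, y = make_classification(n_samples=200000, n_features=36, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

for name, clf in [('sklearn', KNeighborsClassifier(n_neighbors=5)),
                  ('faiss', FaissKNNClassifier(n_neighbors=5, algorithm='brute'))]:
    start = time.perf_counter()
    clf.fit(X_train, y_train)
    clf.predict(X_test)
    print("{}: {:.2f}s".format(name, time.perf_counter() - start))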
More usage and benchmarking examples:
import time

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from deslib.util.faiss_knn_wrapper import FaissKNNClassifier

n_samples = [1000, 10000, 100000, 1000000, 10000000]
rng = 42

faiss_brute = FaissKNNClassifier(n_neighbors=7, algorithm='brute')
faiss_voronoi = FaissKNNClassifier(n_neighbors=7, algorithm='voronoi')
faiss_hierarchical = FaissKNNClassifier(n_neighbors=7, algorithm='hierarchical')

all_knns = [faiss_brute, faiss_voronoi, faiss_hierarchical]
names = ['faiss_brute', 'faiss_voronoi', 'faiss_hierarchical']

list_fitting_time = []
list_search_time = []

for n in n_samples:
    print("Number of samples: {}".format(n))
    X, y = make_classification(n_samples=n, n_features=20, random_state=rng)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)

    temp_fitting_time = []
    temp_search_time = []
    for name, knn in zip(names, all_knns):
        start = time.perf_counter()  # time.clock() was removed in Python 3.8
        knn.fit(X_train, y_train)
        fitting_time = time.perf_counter() - start
        print("{} fitting time: {}".format(name, fitting_time))

        start = time.perf_counter()
        neighbors, dists = knn.kneighbors(X_test)
        search_time = time.perf_counter() - start
        print("{} neighborhood search time: {}".format(name, search_time))

        temp_fitting_time.append(fitting_time)
        temp_search_time.append(search_time)

    list_fitting_time.append(temp_fitting_time)
    list_search_time.append(temp_search_time)

plt.plot(n_samples, list_search_time)
plt.legend(names)
plt.xlabel("Number of samples")
plt.ylabel("K neighbors search time")
plt.savefig('knn_backbone_benchmark.png')
For further benchmarks, see this blog post:
https://towardsdatascience.com/make-knn-300-times-faster-than-scikit-learns-in-20-lines-5e29d74e76bb