前言:分类是机器学习中的一类重要任务。在机器学习的研究历史中,诞生了大量的分类算法,每种算法都有其优势和不足。
本文汇总了常用的分类算法及其实现方式,方便快速查询使用。(本文使用鸢尾花数据集,属于三分类问题)
以下9种分类算法,使用相同的数据进行训练和测试,在测试集上的准确率(accuracy)分别为:
1.随机森林:100%
2.决策树:100%
3.K近邻:100%
4.支持向量机:100%
5.逻辑回归:96.67%
6.线性支持向量机:100%
7.随机梯度下降:96.67%
8.感知机:100%
9.朴素贝叶斯:96.67%
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn
from sklearn import datasets
from sklearn.metrics import accuracy_score

from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC, LinearSVC
from sklearn.linear_model import LogisticRegression

from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import Perceptron
from sklearn.naive_bayes import GaussianNB

from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score

from sklearn.model_selection import GridSearchCV


def evaluate_classifier(label, clf, x_train, y_train, x_test, y_test):
    """Fit one classifier, print its cross-validation scores, return test accuracy.

    Prints ``label``, then the 3-fold cross-validation accuracies computed on
    the training split, then a blank separator line.

    Args:
        label: Heading printed before the scores.
        clf: Estimator exposing ``fit`` and ``predict``.
        x_train: Training features.
        y_train: Training labels.
        x_test: Held-out features.
        y_test: Held-out labels.

    Returns:
        The accuracy of ``clf`` on the held-out split, as computed by
        ``accuracy_score``.
    """
    print(label)
    clf.fit(x_train, y_train)
    # Cross-validation only ever sees the training split; the test split is
    # reserved for the single accuracy figure returned below.
    cross_score = cross_val_score(clf, x_train, y_train, cv=3, scoring="accuracy")
    print(cross_score)
    y_predict = clf.predict(x_test)
    score = accuracy_score(y_test, y_predict)
    print()
    return score


iris = datasets.load_iris()
x, y = iris.data, iris.target

# Hold out 20% of the samples; the fixed seed keeps the split identical across runs.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)

# (printed heading, estimator) pairs, evaluated in this order. The position of
# each score in `res` follows this order.
classifiers = [
    # 1. Random forest
    ('随机森林分类', RandomForestClassifier(n_estimators=100)),
    # 2. Decision tree
    ('决策树分类', DecisionTreeClassifier()),
    # 3. K-nearest neighbours
    ('KNN', KNeighborsClassifier(n_neighbors=13)),
    # 4. Support vector machine
    ('SVM', SVC(gamma='scale')),
    # 5. Logistic regression
    # NOTE(review): `multi_class` is reportedly deprecated in recent
    # scikit-learn releases -- confirm against the installed version.
    ('LogisticRegression', LogisticRegression(solver='lbfgs', multi_class='ovr')),
    # 6. Linear support vector machine
    ('linear SVM', LinearSVC(max_iter=10000)),
    # 7. Stochastic gradient descent
    ('SGD', SGDClassifier(max_iter=1000, tol=1e-3)),
    # 8. Perceptron
    ('Perceptron', Perceptron(max_iter=1000, tol=1e-3)),
    # 9. Gaussian naive Bayes
    ('Naive Bayes', GaussianNB()),
]

# Test-set accuracy of every classifier, in the order listed above.
res = []
for label, clf in classifiers:
    res.append(evaluate_classifier(label, clf, x_train, y_train, x_test, y_test))

# 10. Score comparison
print(res)