学习笔记155—机器学习之分类器——Matlab中各种分类器的使用总结(随机森林、支持向量机、K近邻分类器、朴素贝叶斯等)
Matlab中常用的分类器有随机森林分类器、支持向量机(SVM)、K近邻分类器、朴素贝叶斯、集成学习方法和鉴别分析分类器等。各分类器的相关Matlab函数使用方法如下:
首先对以下介绍中所用到的一些变量做统一的说明:
train_data——训练样本,矩阵的每一行数据构成一个样本,每列表示一种特征
train_label——训练样本标签,为列向量
test_data——测试样本,矩阵的每一行数据构成一个样本,每列表示一种特征
test_label——测试样本标签,为列向量
①随机森林分类器(Random Forest)
TB=TreeBagger(nTree,train_data,train_label);
predict_label=predict(TB,test_data);
②支持向量机(Support Vector Machine,SVM)
SVMmodel=svmtrain(train_data,train_label);
predict_label=svmclassify(SVMmodel,test_data);
③K近邻分类器(KNN)
KNNmodel=ClassificationKNN.fit(train_data,train_label,'NumNeighbors',1);
predict_label=predict(KNNmodel,test_data);
④朴素贝叶斯(Naive Bayes)
Bayesmodel=NaiveBayes.fit(train_data,train_label);
predict_label=predict(Bayesmodel,test_data);
⑤集成学习方法(Ensembles for Boosting)
Bmodel=fitensemble(train_data,train_label,'AdaBoostM1',100,'tree','type','classification');
predict_label=predict(Bmodel,test_data);
⑥鉴别分析分类器(Discriminant Analysis Classifier)
DACmodel=ClassificationDiscriminant.fit(train_data,train_label);
predict_label=predict(DACmodel,test_data);
具体使用如下:(练习数据下载地址如下http://en.wikipedia.org/wiki/Iris_flower_data_set,简单介绍一下该数据集:有一批花可以分为3个品种,不同品种的花的花萼长度、花萼宽度、花瓣长度、花瓣宽度会有差异,根据这些特征实现品种分类)
%% 随机森林分类器(Random Forest)
nTree=10;
B=TreeBagger(nTree,train_data,train_label,'Method', 'classification');
predictl=predict(B,test_data);
predict_label=str2num(cell2mat(predictl));
Forest_accuracy=length(find(predict_label == test_label))/length(test_label)*100;
%% 支持向量机
% SVMStruct = svmtrain(train_data, train_label);
% predictl=svmclassify(SVMStruct,test_data);
% predict_label=str2num(cell2mat(predictl));
% SVM_accuracy=length(find(predict_label == test_label))/length(test_label)*100;
%% K近邻分类器(KNN)
% mdl = ClassificationKNN.fit(train_data,train_label,'NumNeighbors',1);
% predict_label=predict(mdl, test_data);
% KNN_accuracy=length(find(predict_label == test_label))/length(test_label)*100
%% 朴素贝叶斯 (Naive Bayes)
% nb = NaiveBayes.fit(train_data, train_label);
% predict_label=predict(nb, test_data);
% Bayes_accuracy=length(find(predict_label == test_label))/length(test_label)*100;
%% 集成学习方法(Ensembles for Boosting, Bagging, or Random Subspace)
% ens = fitensemble(train_data,train_label,'AdaBoostM1' ,100,'tree','type','classification');
% predictl=predict(ens,test_data);
% predict_label=str2num(cell2mat(predictl));
% EB_accuracy=length(find(predict_label == test_label))/length(test_label)*100;
%% 鉴别分析分类器(discriminant analysis classifier)
% obj = ClassificationDiscriminant.fit(train_data, train_label);
% predictl=predict(obj,test_data);
% predict_label=str2num(cell2mat(predictl));
% DAC_accuracy=length(find(predict_label == test_label))/length(test_label)*100;
%% 练习
% meas=[0 0;2 0;2 2;0 2;4 4;6 4;6 6;4 6];
% [N n]=size(meas);
% species={'1';'1';'1';'1';'-1';'-1';'-1';'-1'};
% ObjBayes=NaiveBayes.fit(meas,species);
% x=[3 3;5 5];
% result=ObjBayes.predict(x);
参考链接:https://blog.csdn.net/jisuanjiguoba/java/article/details/80004568
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)