基于SGD、ASGD算法的SVM分类器(OpenCV案例源码train_svmsgd.cpp解读)
此案例用于二分类问题(鼠标左键、右键点出两类点,会实时画出分界线),最终得到一条分界线(直线):f(x)=weights*x+shift
源码不再贴出,只讲解最核心的doTrain()里的内容。参数含义翻译自ml.hpp文件。
与SVM不同,SVMSGD不需要设置核函数。
【参数】默认值见下述代码
模型类型:SGD、ASGD(推荐)。随机梯度下降、平均随机梯度下降。
边界类型:HARD_MARGIN、SOFT_MARGIN(推荐),前者用于线性可分,后者用于非线性可分
边界规范化 lambda:推荐设为0.0001(对于SGD),0.00001(对于ASGD)。越小,异类被抛弃的越少。
步长 gamma_0
步长降低力度 c:推荐设置为1(对于SGD),0.75(对于ASGD)
终止条件:TermCriteria::COUNT、TermCriteria::EPS、TermCriteria::COUNT + TermCriteria::EPS
参数设置函数:
setSvmsgdType()
setMarginType()
setMarginRegularization()
setInitialStepSize()
setStepDecreasingPower()
【使用方式】
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();//创建对象
svmsgd->train(trainData);//训练
svmsgd->save("MySvmsgd.xml");//保存模型
svmsgd->load("MySvmsgd.xml");//加载模型
svmsgd->predict(samples, responses);//预测,结果保存到responses标签中
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift) { //*创建SVMSGD对象 cv::Ptr<SVMSGD> svmsgd = SVMSGD::create(); //创建SVMSGD对象 //*设置参数,以下全是默认参数 //svmsgd->setSvmsgdType(SVMSGD::ASGD); //模型类型 //svmsgd->setMarginType(SVMSGD::SOFT_MARGIN); //边界类型 //svmsgd->setMarginRegularization(0.00001); //边界规范化 //svmsgd->setInitialStepSize(0.05);//步长 //svmsgd->setStepDecreasingPower(0.75); //步长减弱力度 //svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT,1000,1e-3));//终止条件,1000次迭代,0.001每次迭代的精度 //*训练集 cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses); //*训练 svmsgd->train(trainData); if (svmsgd->isTrained()) //获取分界线的系数,f(x)=weights*x+shift { weights = svmsgd->getWeights();//x系数 shift = svmsgd->getShift();//常数项 //*保存模型 svmsgd->save("svmsgd.xml"); //保存训练好的模型 return true; } return false; }
得到的xml中,weights有两个数,shift有一个数。
f(x)=weights*x+shift,不可以理解为y=kx+b,应该理解为Ax+By+C=0。weights的两个数就是A、B,shift是C。
Mat weights(1, 2, CV_32FC1); weights是一个1*2的向量,x也是1*2的向量(xi,xj)也就是(x,y)坐标。
公式写全了就是:f(x)=weights1*xi+weights2*xj+shift,其实就是weights与x这两个向量的内积(对应相乘在求和)
f(x)如果等于0,说明点在此直线上,大于0就在线的一边,小于0在线的另一边。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?
2017-03-06 描述性统计量