SVM的使用train()

注意:数据结构的一致性,在高维度数据一般使用rbf核函数,使用网格搜索思想迭代求出gamma和c。

每行为一个样本,数据类型都围绕标黄代码而定义的。

SVM训练如下坐标(左边一列为A类,右边为B类),然后预测给出的坐标属于哪一类。

复制代码
#include<opencv2\opencv.hpp>
#include<iostream>
#include<opencv2\ml.hpp> //引入机器学习
using namespace cv;
using namespace std;
using namespace ml;

int main()
{
    //*1、类别标签labelsMat,因为其是短整型,所以labels定义成int类型。最后再转回char
    int labels[14] = { 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B', 'B', 'B' };
    Mat labelsMat(14, 1, CV_32S);//短整型
    for (int i = 0; i < labelsMat.rows; i++)
    {
        labelsMat.at<int>(i, 0) = labels[i];
    }
    //*2、用于训练的样本集trainingDataMat
    int trainingData[14][2] = { { 110, 204 }, { 105, 306 }, { 102, 410 }, { 99, 511 }, { 93, 610 }, { 89, 713 }, { 89, 817 },
    { 173, 208 }, { 175, 313 }, { 167, 415 }, { 163, 514 }, { 160, 612 }, { 156, 716 }, { 152, 819 } };
    Mat trainingDataMat(14, 2, CV_32F); //float类型
    for (int i = 0; i < trainingDataMat.rows; i++)
    {
        for (int j = 0; j < trainingDataMat.cols; j++)
        {
            trainingDataMat.at<float>(i, j) = trainingData[i][j];
        }
    }
    //*3、初始化SVM,参数参考 https://blog.csdn.net/qq_27278957/article/details/88736516
    Ptr<ml::SVM> svm = ml::SVM::create();
    svm->setType(SVM::C_SVC); //svm的类型,
    svm->setKernel(SVM::LINEAR); //核函数
    svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, FLT_EPSILON)); //终止条件
    //*4、训练模型
    Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);//训练样本的数据类型必须是CV_32F,标签可以是CV_32S或其他。
    svm->train(tData);
    svm->save("svmData.xml");
    //*5、预测
    Mat tmp(1, 2, CV_32F);
    tmp.at<float>(0, 0) = 163;
    tmp.at<float>(0, 1) = 600;

    char label = (char)svm->predict(tmp); //ASCII码转字符,预测结果为B
    cout << label << endl;

    waitKey(0);
    return 0;
}
复制代码

上图绘制代码:

复制代码
Mat plot(900, 900, CV_8U);
vector<Point> myPoint(14);//14个点
for (int i = 0; i < myPoint.size(); i++)
{
    myPoint[i].x = trainingData[i][0];
    myPoint[i].y = trainingData[i][1];
    circle(plot, myPoint[i], 15, Scalar(255), -1);
}
namedWindow("坐标点", 0);
imshow("坐标点", plot);
复制代码

 【参考】

https://blog.csdn.net/bigFatCat_Tom/article/details/95201903?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task

posted @   夕西行  阅读(1604)  评论(0编辑  收藏  举报
编辑推荐:
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?
点击右上角即可分享
微信分享提示