SVM的使用train()

注意:数据结构的一致性,在高维度数据一般使用rbf核函数,使用网格搜索思想迭代求出gamma和c。

每行为一个样本,数据类型都围绕标黄代码而定义的。

SVM训练如下坐标(左边一列为A类,右边为B类),然后预测给出的坐标属于哪一类。

#include<opencv2\opencv.hpp>
#include<iostream>
#include<opencv2\ml.hpp> //引入机器学习
using namespace cv;
using namespace std;
using namespace ml;

int main()
{
    //*1、类别标签labelsMat,因为其是短整型,所以labels定义成int类型。最后再转回char
    int labels[14] = { 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B', 'B', 'B' };
    Mat labelsMat(14, 1, CV_32S);//短整型
    for (int i = 0; i < labelsMat.rows; i++)
    {
        labelsMat.at<int>(i, 0) = labels[i];
    }
    //*2、用于训练的样本集trainingDataMat
    int trainingData[14][2] = { { 110, 204 }, { 105, 306 }, { 102, 410 }, { 99, 511 }, { 93, 610 }, { 89, 713 }, { 89, 817 },
    { 173, 208 }, { 175, 313 }, { 167, 415 }, { 163, 514 }, { 160, 612 }, { 156, 716 }, { 152, 819 } };
    Mat trainingDataMat(14, 2, CV_32F); //float类型
    for (int i = 0; i < trainingDataMat.rows; i++)
    {
        for (int j = 0; j < trainingDataMat.cols; j++)
        {
            trainingDataMat.at<float>(i, j) = trainingData[i][j];
        }
    }
    //*3、初始化SVM,参数参考 https://blog.csdn.net/qq_27278957/article/details/88736516
    Ptr<ml::SVM> svm = ml::SVM::create();
    svm->setType(SVM::C_SVC); //svm的类型,
    svm->setKernel(SVM::LINEAR); //核函数
    svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, FLT_EPSILON)); //终止条件
    //*4、训练模型
    Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);//训练样本的数据类型必须是CV_32F,标签可以是CV_32S或其他。
    svm->train(tData);
    svm->save("svmData.xml");
    //*5、预测
    Mat tmp(1, 2, CV_32F);
    tmp.at<float>(0, 0) = 163;
    tmp.at<float>(0, 1) = 600;

    char label = (char)svm->predict(tmp); //ASCII码转字符,预测结果为B
    cout << label << endl;

    waitKey(0);
    return 0;
}

上图绘制代码:

Mat plot(900, 900, CV_8U);
vector<Point> myPoint(14);//14个点
for (int i = 0; i < myPoint.size(); i++)
{
    myPoint[i].x = trainingData[i][0];
    myPoint[i].y = trainingData[i][1];
    circle(plot, myPoint[i], 15, Scalar(255), -1);
}
namedWindow("坐标点", 0);
imshow("坐标点", plot);

 【参考】

https://blog.csdn.net/bigFatCat_Tom/article/details/95201903?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task

posted @ 2020-02-28 11:15  夕西行  阅读(1585)  评论(0编辑  收藏  举报