(1)、cv::ml::Knearest类:继承自cv::ml::StateModel,而cv::ml::StateModel又继承自cv::Algorithm;

(2)、create函数:为static,new一个KNearestImpl用来创建一个KNearest对象;

(3)、setDefaultK/getDefaultK函数:在预测时,设置/获取的K值;

(4)、setIsClassifier/getIsClassifier函数:设置/获取应用KNN是进行分类还是回归;

(5)、setEmax/getEmax函数:在使用KDTree算法时,设置/获取Emax参数值;

(6)、setAlgorithmType/getAlgorithmType函数:设置/获取KNN算法类型,目前支持两种:brute_force和KDTree;

(7)、findNearest函数:根据输入预测分类/回归结果。

 

复制代码
#include<iostream>
#include <opencv2\opencv.hpp>
using namespace cv;
using namespace std;

int main()
{
    Mat img = imread("1.png");
    Mat gray;
    cvtColor(img, gray, CV_BGR2GRAY);
    threshold(gray, gray, 0, 255, CV_THRESH_BINARY);
    // digits.png为2000 * 1000,其中每个数字的大小为20 * 20,
    // 总共有5000((2000*1000) / (20*20))个数字,类型为[0~9],
    // [0~9]10个数字每个数字有5000/10 = 500个样本
    // 对其分割成单个20 * 20的图像并序列化成(转化成一个一维的数组)
    int side = 20;
    int m = gray.rows / side;
    int n = gray.cols / side;
    Mat data, labels;
    for (int i = 0; i < m; i++) {

        int offsetRow = i * side;
        for (int j = 0; j < n; j++) {

            int offsetCol = j * side;
            // 截取20*20的小块


            Mat tmp;

            gray(Range(offsetRow, offsetRow + side), Range(offsetCol, offsetCol + side)).copyTo(tmp);

            data.push_back(tmp.reshape(0, 1));  // 序列化转换成一个一维向量
            labels.push_back(i / 5);           // 每500个为一个label类型            
        }
    }
    data.convertTo(data, CV_32F);
    cout << "读取结束..." << endl;
    //****************** 使用KNN算法训练********************//
    int K = 7;    // 改变K值可能会出现不同的效果,K值越大,识别速度越慢
    //为static,new一个ModelKnn用来创建一个KNearest对象;
    Ptr<ml::KNearest> ModelKnn = ml::KNearest::create();
    ModelKnn->setDefaultK(K);
    ModelKnn->setIsClassifier(true);
    //ModelKnn->train(data, cv::ml::ROW_SAMPLE, labels);
    Ptr<ml::TrainData> tData = ml::TrainData::create(data, ml::ROW_SAMPLE, labels);
    ModelKnn->train(tData);
    ModelKnn->save("KnnTest.xml");
    ///********************测试模型***************************///
    Mat test = imread(".\\test\\3.jpg", 0);//截取图像中一个数字
    Mat bw;
    threshold(test, bw, 0, 255, CV_THRESH_BINARY);
    Mat I0 = bw.reshape(0, 1);
    I0.convertTo(I0, CV_32F);
    // 开始用KNN预测分类,返回识别结果
    float r = ModelKnn->predict(I0);

}
复制代码