opencv笔记--Kmeans

    在图像分割中,使用 kmeans 算法可以实现图像区域基本分割。如果一幅图像被分为两类,kmeans 分割效果与  ostu 算法基本一致,具体如下图:

      

    

    kmeans 将图像灰度聚类为 k 类,ostu 将图像灰度分割为 2 类,当 k = 2 时,两种算法最终目的基本趋于一致。

    kmeans 算法基本思路如下:

    1)随机选取第一个聚类中心点,之后的聚类中心点选取有两种方法;

         a. 随机选取其他 k - 1 个聚类中心点;

         b. 根据已经选取的聚类中心点,计算所有点到已经选取的聚类中心点的距离,选择到所有已经选取的聚类中心点的最远点作为下一个聚类中心点;

    2)根据点到已经选取的聚类中心点的距离对其进行分类;

    3)重新求各个分类的聚类中心点,然后回到 2);

    4)当不再满足迭代条件时给出最终聚类结果,迭代条件包括:

          a. 聚类中心点在迭代过程中的偏移量;

          b. 迭代次数;

    对于聚类中心点的选择,一般情况下,方法 b 会得到更好的聚类,且迭代速度较快。

    opencv 提供的 kmean 函数为:

    double kmeans( InputArray data, int K, InputOutputArray bestLabels, TermCriteria criteria, int attempts,

                              int flags, OutputArray centers=noArray() );

    参数如下:

    data: 待分类点矩阵,其类型必须为 CV_32F;

     K,bestLabels: 聚类数与待分类点所属分类;

     criteria:停止条件;

     attempts:使用不同的随机聚类中心点尝试聚类次数;

     flags:聚类中心点选择方案,包括完全随机选择,kmeans++选择方案(b),用户输入;

     centers:最终聚类中心点;

     以下给出 kmeans 算法使用代码:

     

 1 void UseKmeans(cv::Mat& src, cv::Mat& rst)
 2 {
 3     int width = src.cols;
 4     int height = src.rows;
 5     int dims = src.channels();
 6     int sampleCount = width * height;
 7 
 8     int clusterCount = 2;
 9     Mat points(sampleCount, dims, CV_32F, Scalar(10));
10     cv::Mat pos(sampleCount, 2, CV_16S, Scalar(0, 0));
11     Mat labels;
12     Mat centers(clusterCount, 1, points.type());
13 
14     // invert to data points
15     int index = 0;
16     for (int row = 0; row < height; row++) {
17         for (int col = 0; col < width; col++) {
18             points.at<float>(index, 0) = static_cast<int>(src.ptr<uchar>(row)[col]);
19             pos.at<short>(index, 0) = static_cast<short>(row);
20             pos.at<short>(index, 1) = static_cast<int>(col);
21             ++index;
22         }
23     }
24 
25     // k-mean algorithm
26     TermCriteria criteria = TermCriteria(CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 100, 1.0);
27     kmeans(points, clusterCount, labels, criteria, 3, KMEANS_PP_CENTERS, centers);
28 
29     int bright_val = -1;
30     for (int i = 0; i < centers.rows; ++i)
31     {
32         int val = centers.at<float>(i, 0);
33         if (val > bright_val)
34             bright_val = val;
35     }
36 
37     int bright_label = -1;
38     for (int idx = 0; idx < sampleCount; ++idx)
39     {
40         float *datapoint = points.ptr<float>(idx);
41         int *datalabel = labels.ptr<int>(idx);
42         if (datapoint[0] >= bright_val)
43         {
44             bright_label = datalabel[0];
45             break;
46         }
47     }
48 
49     // save result
50     rst.create(src.size(), CV_8UC1);
51     rst.rowRange(0, rst.rows) = 0;
52     for (int idx = 0; idx < sampleCount; ++idx)
53     {
54         int *datalabel = labels.ptr<int>(idx);
55         if (datalabel[0] == bright_label)
56         {
57             int row = pos.at<short>(idx, 0);
58             int col = pos.at<short>(idx, 1);
59             rst.ptr<uchar>(row)[col] = 255;
60         }
61     }
62 }

 

     

     

     

     

   

posted @ 2020-10-28 12:27  罗飞居  阅读(346)  评论(0编辑  收藏  举报