1 #include <opencv2/opencv.hpp>
2 #include <iostream>
3
4 using namespace cv;
5 using namespace cv::ml;
6 using namespace std;
7
8 int main(int argc, char** argv) {
9 Mat src = imread("toux.jpg");
10 if (src.empty()) {
11 printf("could not load iamge...\n");
12 return -1;
13 }
14 namedWindow("input image", CV_WINDOW_AUTOSIZE);
15 imshow("input image", src);
16
17 // 初始化
18 int numCluster = 3;
19 const Scalar colors[] = {
20 Scalar(255, 0, 0),
21 Scalar(0, 255, 0),
22 Scalar(0, 0, 255),
23 Scalar(255, 255, 0)
24 };
25
26 int width = src.cols;
27 int height = src.rows;
28 int dims = src.channels();
29 int nsamples = width*height;
30 Mat points(nsamples, dims, CV_64FC1);
31 Mat labels;
32 Mat result = Mat::zeros(src.size(), CV_8UC3);
33
34 // 图像RGB像素数据转换为样本数据
35 int index = 0;
36 for (int row = 0; row < height; row++) {
37 for (int col = 0; col < width; col++) {
38 index = row*width + col;
39 Vec3b rgb = src.at<Vec3b>(row, col);
40 points.at<double>(index, 0) = static_cast<int>(rgb[0]);
41 points.at<double>(index, 1) = static_cast<int>(rgb[1]);
42 points.at<double>(index, 2) = static_cast<int>(rgb[2]);
43 }
44 }
45
46 // EM Cluster Train
47 Ptr<EM> em_model = EM::create();
48 em_model->setClustersNumber(numCluster);
49 em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);//设置协方差矩阵
50 //设置停止条件,训练100次结束
51 em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));
52 em_model->trainEM(points, noArray(), labels, noArray());
53
54 // 对每个像素标记颜色与显示
55 Mat sample(dims, 1, CV_64FC1);
56 double time = getTickCount();
57 int r = 0, g = 0, b = 0;
58 for (int row = 0; row < height; row++) {
59 for (int col = 0; col < width; col++) {
60 index = row*width + col;
61 int label = labels.at<int>(index, 0);
62 Scalar c = colors[label];
63 result.at<Vec3b>(row, col)[0] = c[0];
64 result.at<Vec3b>(row, col)[1] = c[1];
65 result.at<Vec3b>(row, col)[2] = c[2];
66
67 /*b = src.at<Vec3b>(row, col)[0];
68 g = src.at<Vec3b>(row, col)[1];
69 r = src.at<Vec3b>(row, col)[2];
70 sample.at<double>(0) = b;
71 sample.at<double>(1) = g;
72 sample.at<double>(2) = r;
73 int response = cvRound(em_model->predict2(sample, noArray())[1]);
74 Scalar c = colors[response];
75 result.at<Vec3b>(row, col)[0] = c[0];
76 result.at<Vec3b>(row, col)[1] = c[1];
77 result.at<Vec3b>(row, col)[2] = c[2];*/
78
79 }
80 }
81 printf("execution time(ms) : %.2f\n", (getTickCount() - time)/getTickFrequency()*1000);
82 imshow("EM-Segmentation", result);
83
84 waitKey(0);
85 return 0;
86 }