OpenCV：GMM（高斯混合模型，对应 API 为 cv::ml::EM）
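GMM方法概述：基于高斯混合模型（Gaussian Mixture Model, GMM）的期望最大化（EM）聚类算法。相关概念：高斯分布与概率密度函数（PDF）、高斯混合模型、模型初始化。高斯混合模型的概率密度为 $p(\mathbf{x}) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$，其中 $\pi_k \ge 0$ 为第 $k$ 个分量的权重且 $\sum_{k=1}^{K} \pi_k = 1$。EM 算法先初始化各分量参数，再交替执行 E 步（计算每个样本属于各分量的后验概率）与 M 步（据此更新 $\pi_k, \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k$），直至满足终止条件。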
初始化EM模型:
// Excerpt: create and train a Gaussian-mixture EM model (numCluster, points and
// labels are defined by the full program that follows).
// Create an empty EM (expectation-maximization) model.
Ptr<EM> em_model = EM::create();
// Number of Gaussian components (clusters) to fit.
em_model->setClustersNumber(numCluster);
// Spherical covariance model: per the OpenCV EM docs each component's covariance
// is a scaled identity matrix (one variance per component).
em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
// Stop after at most 100 iterations, or earlier once the EM convergence measure
// changes by less than 0.1 (EPS + COUNT: whichever comes first).
em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));
// Run EM on `points`; `labels` receives the most probable component index for
// each training sample. The returned success flag is ignored in this excerpt.
em_model->trainEM(points, noArray(), labels, noArray());
#include <opencv2/opencv.hpp>
#include <iostream>

using namespace cv;
using namespace cv::ml;
using namespace std;

/**
 * GMM (EM) clustering demo on synthetic 2-D data.
 *
 * Draws 2-4 normally distributed blobs of points inside a 500x500 canvas,
 * fits a Gaussian mixture with spherical covariances via EM, colours every
 * canvas pixel (dimmed to 75%) with the component EM::predict2() assigns to
 * it, and finally draws the training samples in their training-label colour.
 *
 * @return 0 on success, -1 if EM training fails.
 */
int main(int argc, char** argv)
{
    Mat img = Mat::zeros(500, 500, CV_8UC3);
    RNG rng(12345);
    // One colour per component; numCluster is at most 4, so five entries suffice.
    Scalar colorTab[] = {
        Scalar(0, 0, 255),
        Scalar(0, 255, 0),
        Scalar(255, 0, 0),
        Scalar(0, 255, 255),
        Scalar(255, 0, 255)
    };

    // RNG::uniform(a, b) draws from [a, b): 2..4 clusters, 5..999 samples.
    int numCluster = rng.uniform(2, 5);
    printf("number of clusters : %d\n", numCluster);
    int sampleCount = rng.uniform(5, 1000);

    // Generate the random samples as an N x 1 two-channel (x, y) matrix.
    // With two channels RNG::fill applies the per-channel mean/stddev (x uses
    // center.x, y uses center.y) and randShuffle moves whole (x, y) points.
    // An N x 2 single-channel layout would make RNG::fill use only the first
    // Scalar component for both coordinates and make randShuffle permute
    // individual coordinates, breaking every (x, y) pair.
    Mat xy(sampleCount, 1, CV_32FC2);
    for (int k = 0; k < numCluster; k++) {
        Point center;
        center.x = rng.uniform(0, img.cols);
        center.y = rng.uniform(0, img.rows);
        Mat pointChunk = xy.rowRange(k * sampleCount / numCluster,
                                     k == numCluster - 1 ? sampleCount : (k + 1) * sampleCount / numCluster);
        rng.fill(pointChunk, RNG::NORMAL, Scalar(center.x, center.y),
                 Scalar(img.cols * 0.05, img.rows * 0.05));
    }
    randShuffle(xy, 1, &rng);
    // EM expects a single-channel matrix with one sample per row:
    // view the same data as N x 2 CV_32FC1 (no copy).
    Mat points = xy.reshape(1);
    Mat labels;

    // Initialise and train the EM model.
    Ptr<EM> em_model = EM::create();
    em_model->setClustersNumber(numCluster);
    em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
    em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));
    if (!em_model->trainEM(points, noArray(), labels, noArray())) {
        printf("EM training failed\n");
        return -1;
    }

    // Colour every pixel by its most probable component. The pixel is written
    // directly: drawing a radius-1 filled disc per pixel would let the discs of
    // later neighbouring pixels overwrite this pixel's own prediction.
    Mat sample(1, 2, CV_32FC1);
    for (int row = 0; row < img.rows; row++) {
        for (int col = 0; col < img.cols; col++) {
            sample.at<float>(0) = static_cast<float>(col);
            sample.at<float>(1) = static_cast<float>(row);
            // predict2 returns (log-likelihood, index of the most probable component).
            int response = cvRound(em_model->predict2(sample, noArray())[1]);
            Scalar c = colorTab[response] * 0.75;  // dimmed background shade
            img.at<Vec3b>(row, col) = Vec3b(saturate_cast<uchar>(c[0]),
                                            saturate_cast<uchar>(c[1]),
                                            saturate_cast<uchar>(c[2]));
        }
    }

    // Draw the samples on top, coloured by the label EM assigned during training.
    for (int i = 0; i < sampleCount; i++) {
        Point p(cvRound(points.at<float>(i, 0)), cvRound(points.at<float>(i, 1)));
        circle(img, p, 1, colorTab[labels.at<int>(i)], -1);
    }

    imshow("GMM-EM Demo", img);
    waitKey(0);
    return 0;
}
#include <opencv2/opencv.hpp>
#include <iostream>

using namespace cv;
using namespace cv::ml;
using namespace std;

/**
 * GMM (EM) colour-based image segmentation demo.
 *
 * Every pixel's (B, G, R) value becomes one 3-D sample. A Gaussian mixture
 * with spherical covariances is fitted with EM, then every pixel is
 * re-classified with EM::predict2() and painted with the display colour of
 * its component. The time spent in the prediction loop is printed.
 *
 * Usage: program [image]   (defaults to "D:/images/cvtest.png")
 *
 * @return 0 on success, -1 if the image cannot be loaded or EM training fails.
 */
int main(int argc, char** argv)
{
    // Optional image path on the command line; the original path is the default.
    const char* imagePath = (argc > 1) ? argv[1] : "D:/images/cvtest.png";
    Mat src = imread(imagePath);  // default flag IMREAD_COLOR -> 3-channel BGR
    if (src.empty()) {
        printf("could not load image...\n");
        return -1;
    }
    // WINDOW_AUTOSIZE replaces the legacy C constant CV_WINDOW_AUTOSIZE, which
    // is not available through opencv.hpp in OpenCV 4.
    namedWindow("input image", WINDOW_AUTOSIZE);
    imshow("input image", src);

    // Display colour per mixture component. The cluster count is derived from
    // this table so a predicted component index can never run past its end.
    const Scalar colors[] = {
        Scalar(255, 0, 0),
        Scalar(0, 255, 0),
        Scalar(0, 0, 255),
        Scalar(255, 255, 0)
    };
    const int numCluster = static_cast<int>(sizeof(colors) / sizeof(colors[0]));

    const int width = src.cols;
    const int height = src.rows;
    const int dims = src.channels();
    const int nsamples = width * height;
    Mat points(nsamples, dims, CV_64FC1);
    Mat labels;
    Mat result = Mat::zeros(src.size(), CV_8UC3);

    // Convert pixel data to samples: one row per pixel in row-major order,
    // columns are the B, G, R channel values.
    for (int row = 0; row < height; row++) {
        for (int col = 0; col < width; col++) {
            const int index = row * width + col;
            const Vec3b bgr = src.at<Vec3b>(row, col);
            for (int ch = 0; ch < 3; ch++) {
                points.at<double>(index, ch) = bgr[ch];
            }
        }
    }

    // EM cluster training.
    Ptr<EM> em_model = EM::create();
    em_model->setClustersNumber(numCluster);
    em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
    em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));
    if (!em_model->trainEM(points, noArray(), labels, noArray())) {
        printf("EM training failed\n");
        return -1;
    }

    // Label and colour every pixel. `labels` already holds the training-time
    // component of pixel (row, col) at index row * width + col and could be used
    // directly; predict2() is called here to demonstrate (and time) per-sample
    // prediction.
    Mat sample(1, dims, CV_64FC1);  // one sample per row, same layout as training
    const int64 startTicks = getTickCount();
    for (int row = 0; row < height; row++) {
        for (int col = 0; col < width; col++) {
            const Vec3b bgr = src.at<Vec3b>(row, col);
            for (int ch = 0; ch < 3; ch++) {
                sample.at<double>(ch) = bgr[ch];
            }
            // predict2 returns (log-likelihood, index of the most probable component).
            const int response = cvRound(em_model->predict2(sample, noArray())[1]);
            const Scalar& c = colors[response];
            result.at<Vec3b>(row, col) = Vec3b(saturate_cast<uchar>(c[0]),
                                               saturate_cast<uchar>(c[1]),
                                               saturate_cast<uchar>(c[2]));
        }
    }
    printf("execution time(ms) : %.2f\n", (getTickCount() - startTicks) / getTickFrequency() * 1000);

    imshow("EM-Segmentation", result);
    waitKey(0);
    return 0;
}