OpenCV：GMM（高斯混合模型，对应 API 为 cv::ml::EM）

 

GMM方法概述：以高斯混合模型（GMM）描述数据，用期望最大化（EM）算法迭代估计各高斯分量的参数，并据此完成聚类。
    高斯混合模型 (GMM)
    高斯分布与概率密度函数 (PDF)
    初始化

 初始化并训练EM模型：
    Ptr<EM> em_model = EM::create();
    em_model->setClustersNumber(numCluster);
    em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
    em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));
    em_model->trainEM(points, noArray(), labels, noArray());

 

 

#include <opencv2/opencv.hpp>
#include <iostream>

using namespace cv;
using namespace cv::ml;
using namespace std;

int main(int argc, char** argv) {
    Mat img = Mat::zeros(500, 500, CV_8UC3);
    RNG rng(12345);

    Scalar colorTab[] = {
        Scalar(0, 0, 255),
        Scalar(0, 255, 0),
        Scalar(255, 0, 0),
        Scalar(0, 255, 255),
        Scalar(255, 0, 255)
    };

    int numCluster = rng.uniform(2, 5);
    printf("number of clusters : %d\n", numCluster);

    int sampleCount = rng.uniform(5, 1000);
    Mat points(sampleCount, 2, CV_32FC1);
    Mat labels;

    // 生成随机数
    for (int k = 0; k < numCluster; k++) {
        Point center;
        center.x = rng.uniform(0, img.cols);
        center.y = rng.uniform(0, img.rows);
        Mat pointChunk = points.rowRange(k*sampleCount / numCluster,
            k == numCluster - 1 ? sampleCount : (k + 1)*sampleCount / numCluster);

        rng.fill(pointChunk, RNG::NORMAL, Scalar(center.x, center.y), Scalar(img.cols*0.05, img.rows*0.05));
    }

    randShuffle(points, 1, &rng);
    //初始化EM模型
    Ptr<EM> em_model = EM::create();
    em_model->setClustersNumber(numCluster);
    em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
    em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));
    em_model->trainEM(points, noArray(), labels, noArray());

    // 处理每个像素
    Mat sample(1, 2, CV_32FC1);
    for (int row = 0; row < img.rows; row++) {
        for (int col = 0; col < img.cols; col++) {
            sample.at<float>(0) = (float)col;
            sample.at<float>(1) = (float)row;
            int response = cvRound(em_model->predict2(sample, noArray())[1]);
            Scalar c = colorTab[response];
            //填充
            circle(img, Point(col, row), 1, c*0.75, -1);
        }
    }

    // 画出采样数据
    for (int i = 0; i < sampleCount; i++) {
        Point p(cvRound(points.at<float>(i, 0)), points.at<float>(i, 1));
        circle(img, p, 1, colorTab[labels.at<int>(i)], -1);
    }

    imshow("GMM-EM Demo", img);

    waitKey(0);
    return 0;
}

 

 

 

 

#include <opencv2/opencv.hpp>
#include <iostream>

using namespace cv;
using namespace cv::ml;
using namespace std;

int main(int argc, char** argv) {
    Mat src = imread("D:/images/cvtest.png");
    if (src.empty()) {
        printf("could not load iamge...\n");
        return -1;
    }
    namedWindow("input image", CV_WINDOW_AUTOSIZE);
    imshow("input image", src);

    // 初始化
    int numCluster = 4;
    const Scalar colors[] = {
        Scalar(255, 0, 0),
        Scalar(0, 255, 0),
        Scalar(0, 0, 255),
        Scalar(255, 255, 0)
    };

    int width = src.cols;
    int height = src.rows;
    int dims = src.channels();
    int nsamples = width * height;
    Mat points(nsamples, dims, CV_64FC1);
    Mat labels;
    Mat result = Mat::zeros(src.size(), CV_8UC3);

    // 图像RGB像素数据转换为样本数据 
    int index = 0;
    for (int row = 0; row < height; row++) {
        for (int col = 0; col < width; col++) {
            index = row * width + col;
            Vec3b rgb = src.at<Vec3b>(row, col);
            points.at<double>(index, 0) = static_cast<int>(rgb[0]);
            points.at<double>(index, 1) = static_cast<int>(rgb[1]);
            points.at<double>(index, 2) = static_cast<int>(rgb[2]);
        }
    }

    // EM Cluster Train
    Ptr<EM> em_model = EM::create();
    em_model->setClustersNumber(numCluster);
    em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
    em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));
    em_model->trainEM(points, noArray(), labels, noArray());

    // 对每个像素标记颜色与显示
    Mat sample(dims, 1, CV_64FC1);
    double time = getTickCount();
    int r = 0, g = 0, b = 0;
    for (int row = 0; row < height; row++) {
        for (int col = 0; col < width; col++) {
            /*index = row * width + col;
            int label = labels.at<int>(index, 0);
            Scalar c = colors[label];
            result.at<Vec3b>(row, col)[0] = c[0];
            result.at<Vec3b>(row, col)[1] = c[1];
            result.at<Vec3b>(row, col)[2] = c[2];*/
            
            b = src.at<Vec3b>(row, col)[0];
            g = src.at<Vec3b>(row, col)[1];
            r = src.at<Vec3b>(row, col)[2];
            sample.at<double>(0) = b;
            sample.at<double>(1) = g;
            sample.at<double>(2) = r;
            int response = cvRound(em_model->predict2(sample, noArray())[1]);
            Scalar c = colors[response];
            result.at<Vec3b>(row, col)[0] = c[0];
            result.at<Vec3b>(row, col)[1] = c[1];
            result.at<Vec3b>(row, col)[2] = c[2];
        }
    }
    printf("execution time(ms) : %.2f\n", (getTickCount() - time) / getTickFrequency() * 1000);
    imshow("EM-Segmentation", result);

    waitKey(0);
    return 0;
}

 

posted @ 2019-10-25 13:47  osbreak  阅读(1027)  评论(0编辑  收藏  举报