GMM-实现聚类的代码示例

Matlab 代码:

 1 % GMM code
 2 
 3 function varargout = gmm(X, K_or_centroids)
 4 
 5     % input X:N-by-D data matrix
 6     % input K_or_centroids: K-by-D centroids
 7     
 8     % 阈值
 9     threshold = 1e-15;
10     % 读取数据维度
11     [N, D] = size(X);
12     % 判断输入质心是否为标量
13     if isscalar(K_or_centroids)
14         % 是标量,随机选取K个质心
15         K = K_or_centroids;
16         rnpm = randperm(N); % 打乱的N个序列
17         centroids = X(rnpm(1:K), :);
18     else   % 矩阵,给出每一类的初始化
19         K = size(K_or_centroids, 1);
20         centroids = K_or_centroids;
21     end
22     
23     % 定义模型初值
24     [pMiu pPi pSigma] = init_params();
25     
26     Lprev = -inf;
27     while true
28         % E-step,估算出概率值
29         % Px: N-by-K 
30         Px = calc_prob();
31         
32         % pGamma新的值,样本点所占的权重
33         % pPi:1-by-K     pGamma:N-by-K
34         pGamma = Px ./ repmat(pPi, N, 1);
35         % 对pGamma的每一行进行求和,sum(x,2):每一行求和
36         pGamma = pGamma ./ repmat(sum(pGamma, 2) ,1 , K);
37         
38         % M-step
39         % 每一个组件给予新的值
40         Nk = sum(pGamma,1);
41         pMiu = diag(1./Nk)*pGamma'*X;
42         pPi = Nk/N;
43         for kk = 1:K
44            Xshift = X - repmat(pMiu(kk, :) ,N, 1);
45            pSigma(:,:,kk) = (Xshift'*(diag(pGamma(:,kk))*Xshift)) / Nk(kk);
46         end
47         
48         % 观察收敛,convergence
49         L = sum(log(Px*pPi'));
50         if L-Lprev < threshold
51             break;
52         end
53         Lprev = L;
54         
55     end
56     
57     % 输出参数判定
58     if nargout == 1
59         varargout = {Px};
60     else
61         model = [];
62         model.Miu = pMiu;
63         model.Sigma = pSigma;
64         model.Pi = pPi;
65         varargout = {Px, model};
66     end
67     
68     function [pMiu pPi pSigma] = init_params()
69        pMiu = centroids; % 均值,K类的中心
70        pPi = zeros(1, K); % 概率
71        pSigma = zeros(D, D, K); % 协方差,每一个都是D-by-D
72        
73        % (X - pMiu)^2 = X^2 + pMiu^2 - 2*X*pMiu
74        distmat = repmat(sum(X.*X, 2), 1, K) + repmat(sum(pMiu.*pMiu, 2)', N, 1) - 2*X*pMiu';
75        [dummy labels] = min(distmat, [], 2); % 找出每一行的最小值,并标出列的位置
76        
77        for k=1:K   %初始化参数
78            Xk = X(labels == k, :);
79            pPi(k) = size(Xk, 1)/N;
80            pSigma(:, :, k) = cov(Xk);
81        end             
82     end
83 
84     % 计算概率值
85     function Px = calc_prob()
86         Px = zeros(N,K);
87         for k=1:K
88             Xshift = X - repmat(pMiu(k,:),N,1);
89             inv_pSigma = inv(pSigma(:,:,k)+diag(repmat(threshold, 1, size(pSigma(:,:,k),1))));
90             tmp = sum((Xshift*inv_pSigma).*Xshift, 2);
91             coef = (2*pi)^(-D/2)*sqrt(det(inv_pSigma));
92             Px(:,k) = coef * exp(-1/2*tmp);
93         end
94     end
95       
96 
97 end

 

测试主程序:

 1 % 测试代码
 2 clear all
 3 clc
 4 
 5 data = load('testSet.txt');
 6 [PX, Model] = gmm(data, 4);
 7 [~,index] = max(PX'); % 每一列的最大值
 8 
 9 cent = Model.Miu;
10 figure
11 I = find(index == 1);
12 scatter(data(I,1), data(I,2))
13 hold on 
14 scatter(cent(1,1), cent(1,2) ,150, 'filled');
15 hold on
16 I = find(index == 2);
17 scatter(data(I,1),data(I,2))
18 hold on
19 scatter(cent(2,1),cent(2,2),150,'filled')
20 hold on
21 I = find(index == 3);
22 scatter(data(I,1),data(I,2))
23 hold on
24 scatter(cent(3,1),cent(3,2),150,'filled')
25 hold on
26 I = find(index == 4);
27 scatter(data(I,1),data(I,2))
28 hold on
29 scatter(cent(4,1),cent(4,2),150,'filled')

 

示意图:

 

 参考自:http://www.voidcn.com/blog/llp1992/article/p-2308490.html

 

posted @ 2017-07-06 19:19  今夜无风  阅读(6781)  评论(0编辑  收藏  举报