使用西瓜数据集4.0实现k-means算法

% K-means clustering on watermelon dataset 4.0
%% Data: column 1 is density (密度), column 2 is sugar content (含糖量)
data=[0.697,0.460;0.771,0.376;0.634,0.264;0.608,0.318;0.556,0.215;0.403,0.237;0.481,0.149;0.437,0.211;0.666,0.091;0.243,0.267;0.245,0.057;0.343,0.099;0.639,0.161;0.657,0.198;0.360,0.370;0.593,0.042;0.719,0.103;0.359,0.188;0.339,0.241;0.282,0.257;0.748,0.232;0.714,0.246;0.483,0.312;0.478,0.437;0.525,0.369;0.751,0.489;0.532,0.472;0.473,0.376;0.725,0.445;0.446,0.459];
%% Pick k distinct random samples as the initial mean vectors
n=size(data,1);   % sample count, derived from the data instead of hard-coded
k=3;              % number of clusters
flag=randperm(n,k);
% Column 3 of data stores the cluster label. The k seed rows are labelled
% 1..k; growing data to three columns leaves every other row at 0 until the
% assignment step runs.
data(flag,3)=(1:k)';
% k-by-2 matrix of initial means (one seed sample per row)
mu=data(flag,1:2);
%% Initial assignment: label every sample with its nearest initial mean
% dis(i,j) is the Euclidean distance from sample i to mean j. A seed sample
% is at distance 0 from its own mean, so it keeps the label the seeding step
% gave it without a special case. Ties go to the lower cluster index.
dis=zeros(n,k);
for i=1:n
    % Named bestDist (not `min`) so the built-in min() stays usable after the
    % script has run in the base workspace.
    bestDist=Inf;
    category=1;
    for j=1:k
        dis(i,j)=sqrt((data(i,1)-mu(j,1)).^2+(data(i,2)-mu(j,2)).^2);
        if dis(i,j)<bestDist
            bestDist=dis(i,j);
            category=j;
        end
    end
    data(i,3)=category;
end
%% Plot the initial clustering
% Each initial mean is drawn as '+' and every other sample as '*', both in
% the colour of its cluster.
figure(1);
clf;       % start from an empty figure so a re-run does not overlay old points
hold on
x=data(:,1);
y=data(:,2);
colors='rgbcmyk';
for c=1:k
    col=colors(mod(c-1,numel(colors))+1);
    plot(mu(c,1),mu(c,2),[col '+']);
end
for c=1:k
    col=colors(mod(c-1,numel(colors))+1);
    % A sample is the mean only when BOTH coordinates match. The former
    % "x~=mu_x && y~=mu_y" test also hid samples sharing a single coordinate
    % with the mean (e.g. two samples here share sugar content 0.376).
    isMean=(x==mu(c,1)) & (y==mu(c,2));
    members=(data(:,3)==c) & ~isMean;
    plot(x(members),y(members),[col '*']);
end
hold off
xlabel('密度');
ylabel('含糖量');
title('初始')
%% Iterate: recompute the means, stop when none moves, otherwise reassign
% Round 1 was the initial assignment above, so the counter starts at 2 and
% the number of completed assignment rounds (iter-1) is printed on
% convergence. Nothing is printed if maxIter is reached first.
% `sum`/`min` are intentionally not used as variable names so the MATLAB
% built-ins remain callable after the script runs.
maxIter=500;
for iter=2:maxIter
    % Recompute every cluster mean from the labels in column 3.
    newMu=mu;
    for c=1:k
        members=data(:,3)==c;
        if any(members)
            newMu(c,:)=mean(data(members,1:2),1);
        end
        % An empty cluster keeps its previous mean instead of becoming NaN
        % (NaN would compare unequal forever and prevent convergence).
    end
    if isequal(newMu,mu)
        % No mean changed: converged.
        disp(iter-1);
        break;
    end
    mu=newMu;
    % Reassign every sample to its nearest mean; ties go to the lower index.
    for i=1:n
        bestDist=Inf;
        category=1;
        for j=1:k
            dis(i,j)=sqrt((data(i,1)-mu(j,1)).^2+(data(i,2)-mu(j,2)).^2);
            if dis(i,j)<bestDist
                bestDist=dis(i,j);
                category=j;
            end
        end
        data(i,3)=category;
    end
end
%% Plot the final clustering
% Each final mean is drawn as '+' and every sample as '*', both in the
% colour of its cluster (a sample exactly equal to its mean is not
% redrawn as '*').
figure(2);
clf;       % start from an empty figure so a re-run does not overlay old points
hold on
x=data(:,1);
y=data(:,2);
colors='rgbcmyk';
for c=1:k
    col=colors(mod(c-1,numel(colors))+1);
    plot(mu(c,1),mu(c,2),[col '+']);
end
for c=1:k
    col=colors(mod(c-1,numel(colors))+1);
    % A sample coincides with the mean only when BOTH coordinates match; the
    % former "x~=mu_x && y~=mu_y" test also hid samples sharing just one
    % coordinate with the mean.
    isMean=(x==mu(c,1)) & (y==mu(c,2));
    members=(data(:,3)==c) & ~isMean;
    plot(x(members),y(members),[col '*']);
end
hold off
xlabel('密度');
ylabel('含糖量');
title('结果')
posted @ 2021-10-31 10:04  0x3fffffff  阅读(855)  评论(0)  编辑  收藏  举报