二分-k均值算法

首先我们都知道k均值算法有一个炒鸡大的bug,就是在很多情况下他只会收敛到局部最小值而不是全局最小值,为了解决这个问题,很多学者提出了很多的方法,我们在这里介绍一种叫做2分k均值的方法。

该算法首先将所有点作为一个簇,然后将该簇一分为二。之后选择其中一个簇继续进行划分,选择哪一个簇进行划分取决于哪个簇的sse是最大值。上述基于sse的划分过程不断重复,直到得到用户指定的簇数目为止。

将所有的点看成一个簇,当粗的数目小于k时,对每一个簇计算总误差,在给定的粗上进行k均值聚类(k=2),计算将该粗一分为二之后的总误差。最后选择sse最大的簇进行划分。重复执行若干次直到簇的数目等于k。

 

1.首先贴出k均值函数

# coding=utf-8

from numpy import *
import matplotlib
import matplotlib.pyplot as plt
import operator
from os import listdir
import time


def distEclud(vecA, vecB):
    return sqrt(sum(power(vecA - vecB, 2)))  # la.norm(vecA-vecB)


def randCent(dataSet, k):
    n = shape(dataSet)[1]
    centroids = mat(zeros((k, n)))  # create centroid mat
    for j in range(n):  # create random cluster centers, within bounds of each dimension
        minJ = min(dataSet[:, j])
        rangeJ = float(max(dataSet[:, j]) - minJ)
        centroids[:, j] = mat(minJ + rangeJ * random.rand(k, 1))
    return centroids


def kMeans(dataSet, k, distMeas=distEclud, createCent=randCent):
    m = dataSet.shape[0]
    clusterAssment = zeros((m, 2))


    centroids = createCent(dataSet, k)
    # print centroids
    show(centroids)
    clusterChanged = True
    while clusterChanged:
        clusterChanged = False

        for i in range(m):
            point = dataSet[i, :]  # 遍历每个点
            mindist = inf
            minindex = -1

            for n in range(k):
                heart = centroids[n, :]  # 遍历每个质心
                distance = distMeas(point, heart)  # 求点与质心距离
                if distance < mindist:
                    mindist = distance  # 更新最小距离mindist
                    minindex = n  # 更新最小距离的质心序号

            if clusterAssment[i, 0] != minindex: clusterChanged = True
            clusterAssment[i, :] = minindex, mindist ** 2  # 方差
        # print clusterAssment
        for cent in range(k):
            ptsInClust = dataSet[(clusterAssment[:, 0] == cent)]  # get all the point in this cluster
            # print ptsInClust
            if len(ptsInClust):
                centroids[cent, :] = mean(ptsInClust, axis=0)  # assign centroid to mean
            else:
                centroids[cent, :] = array([[0, 0]])
        show(centroids,color='green')

    return centroids, clusterAssment



def show(data,color=None):
    if not color:
        color='green'
    group=createDataSet()
    fig = plt.figure(1)
    axes = fig.add_subplot(111)
    axes.scatter(group[:, 0], group[:, 1], s=40, c='red')
    axes.scatter(data[:, 0], data[:, 1], s=50, c=color)
    plt.show()



def createDataSet():
    group = array([[1.0, 1.1], [1.0, 1.0],
                   [0, 0], [0, 0.1],
                   [2, 1.0], [2.1, 0.9],
                   [0.3, 0.0], [1.1, 0.9],
                   [2.2, 1.0], [2.1, 0.8],
                   [3.3, 3.5], [2.1, 0.9],
                   [2, 1.0], [2.1, 0.9],
                   [3.5, 3.4], [3.6, 3.5]
                   ])


    return group
# centroids, clusterAssment=kMeans(createDataSet(),4)
# show(centroids,color='yellow')

  



在此基础上我们加上二分的算法

def biKmeans(dataSet, k, distMeas=distEclud):
    m = shape(dataSet)[0]     #点数
    clusterAssment = mat(zeros((m,2)))   #空矩阵
    centroid0 = mean(dataSet, axis=0).tolist()  #数据集的平均值,tolist是转换为列表
    # print centroid0
    centList = [centroid0]  # create a list with one centroid  对每一个质心建立一个列表容器
    for j in range(m):     #遍历数据集
        clusterAssment[j,1]=distMeas(mat(centroid0),dataSet[j,:])**2  #求出数据集中每一个点到先前选定质心的距离平方
                                                      #并将其赋值给clusterAssment这个矩阵的对应列的第二个值。
                                                    #而第一个值全都赋值为0表示当前只有一个簇
        # print mat(centroid0),dataSet[j,:]
    # print clusterAssment

    while (len(centList) < k):  #簇的数量小于4时
        lowestSSE = inf            #初始化lowersse为正无穷
        for i in range(len(centList)):    # 遍历当前存在的每一个质心,整个遍历过程是一个找质心的过程,但即便是2分,得到的结果也是不确定的
                          #这个循环的目的只是得到划分哪个质心可以得到最大效益,而不考虑如何划分,主要的原因是k均值存在很
                         # 大偶然性,不能得到确切结果。
            print i

            ptsInCurrCluster = dataSet[nonzero(clusterAssment[:, 0].A == i)[0], :] # 每一个簇所拥有的所有数据集
            # print ptsInCurrCluster               #当前for 循环中只有一个簇
            centroidMat, splitClustAss = kMeans(ptsInCurrCluster, 2, distMeas)       #k-均值算法,k=2
            # print centroidMat
            # axes.scatter(centroidMat[:, 0], centroidMat[:, 1], s=40, c='blue')         #将得到的两个质心描绘出来
            # plt.show()
            # print splitClustAss                      #点分配到两个质心的分配方式矩阵
            sseSplit = sum(splitClustAss[:, 1])           #分配后的点方差之和   误差sse
            # print sseSplit


            sseNotSplit = sum(clusterAssment[nonzero(clusterAssment[:, 0].A != i)[0], 1])  #不在当前分配簇中的点方差之和

            print "sseSplit, and notSplit: ", sseSplit, sseNotSplit
            if (sseSplit + sseNotSplit) < lowestSSE:   # 两种方差和小于先前的lowersse,则说明这种分配方式减小了误差率
                bestCentToSplit = i                         #暂时将当前划分方式设为最佳
                bestNewCents = centroidMat            #暂时将当前的划分质心设为最好
                bestClustAss = splitClustAss.copy()     #复制当前划分的   点分配到两个质心的分配方式矩阵
                lowestSSE = sseSplit + sseNotSplit      #更新最小lowersse

          # 将得到的两个最佳质心描绘出来
        # show(bestNewCents)
        # print bestClustAss
        bestClustAss[nonzero(bestClustAss[:, 0]== 1)[0], 0] = len(centList)  # 2分k聚类返回系数0,或1  ,需要把1换成当前簇数目,以免造成重复
        bestClustAss[nonzero(bestClustAss[:, 0] == 0)[0], 0] = bestCentToSplit  #把0换成别切分的簇,或者与上面的交换赋值也可以
        print 'the bestCentToSplit is: ', bestCentToSplit
        print 'the len of bestClustAss is: ', len(bestClustAss)

        centList[bestCentToSplit] = bestNewCents[0, :].tolist()[0]  # 将centlist指定位置上的质心换为分割后的质心
        centList.append(bestNewCents[1, :].tolist()[0])                #将另一个质心添加上去

        clusterAssment[nonzero(clusterAssment[:, 0].A == bestCentToSplit)[0],:] = bestClustAss  # 将划分后的新质心及点分布赋值给结果矩阵
        # print clusterAssment
        show(mat(centList),color="blue")
    return mat(centList), clusterAssment

cent,clusterAssment= biKmeans(createDataSet(), 4)
show(cent,color='yellow')

这样我们就结束了2分k均值的编写,效果得到了很大的优化。




posted @ 2016-11-21 18:05  GreadLoveJM  阅读(1879)  评论(0编辑  收藏  举报