k-means、k-means++、kernel k-means算法介绍及在datasets-load_iris数据集上的实现

k-means、k-means++、kernel k-means算法介绍及在datasets-load_iris数据集实现

k-means算法作为经典的聚类算法,提出的时间较早,发展至今衍生出很多变体

k-means++作为k-means的改进,优化了其对初始类中心点的选取

kernel k-means利用数据的维度变化,通过提升维度巧妙地解决了k-means只能作用于线性可分数据

聚类目标:

  • 处于同一类之间的点距离较近(相似度较大)
  • 处于不同类之间的点距离较远(相似度较小)

完整实验代码

https://github.com/yangbo981205/k-means-clustering.git

k-means

是一种迭代求解的聚类分析算法,其步骤是,预将数据分为K组,则随机选取K个对象作为初始的聚类中心,然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。聚类中心以及分配给它们的对象就代表一个聚类。每分配一个样本,聚类的聚类中心会根据聚类中现有的对象被重新计算。这个过程将不断重复直到满足某个终止条件。终止条件可以是没有(或最小数目)对象被重新分配给不同的聚类,没有(或最小数目)聚类中心再发生变化,误差平方和局部最小。

公式

m i n A 1 , ⋯   , A c ∑ i = 1 c ∑ x j ∈ A i ∥ x j − m i ∥ 2 2 min_{A_1,\cdots ,A_c}\sum_{i=1}^c \sum_{x_j\in A_i}\Vert x_j-m_i\Vert_2^2 minA1,,Aci=1cxjAixjmi22

其中, m i m_i mi 表示第 i i i 个类的均值, c c c 是类的个数。

k-means算法思路较为简单,接下来的部分为代码实现

数据集处理

# 数据获取
def get_data():
    iris = load_iris()
    data = iris.data
    result = iris.target
    return data, result


# 为方便绘图,将数据集取两个维度
def data_processing(data):
    data_list = []
    for i in data:
        tem_list = [i[2], i[3]]
        data_list.append(tem_list)
    return data_list

选取初始类中心

k-means的初始类中心选取是随机进行的,所以我们根据给定的类别数进行随机选取

# k-means类中心的选择
def select_centre(data, k):
    center_list = []
    for i in range(k):
        ran = np.random.randint(len(data))
        center_list.append(data[ran])
    return center_list

单次k-means运算

k-means是一种迭代求最优解的算法,其每一次的迭代过程如下:

# k-means聚类
def k_means(data, center):
    new_center = []
    cluster_result = []
    class_result = [[] for i in range(len(center))]
    for d in data:
        dis_list = []
        for c in center:
            # 欧氏距离进行距离度量
            dis = np.sqrt(np.sum(np.square(np.array(d)-np.array(c))))
            dis_list.append(dis)
        cluster_result.append(dis_list.index(min(dis_list)))
        class_result[dis_list.index(min(dis_list))].append(d)

    # 更新类中心
    for cls in class_result:
        x = 0
        y = 0
        for c in cls:
            x += c[0]
            y += c[1]
        if len(cls) == 0:
            new_center.append(center[class_result.index(cls)])
        else:
            new_center.append([round(x/len(cls), 4), round(y/len(cls), 4)])

    print(new_center)
    # print(len(class_result[0]), len(class_result[1]), len(class_result[2]))

    return cluster_result, class_result, new_center

兰德指数对聚类结果评判

def metric(result, pred_result):
    print(metrics.adjusted_rand_score(result, pred_result))

运行结果

正确结果

datasets-load_iris数据集是有标注的数据集,因此我们可以用来判断我们聚类结果的准确性

在这里插入图片描述

每一轮迭代的结果

三角形表示类中心,圆圈为类中的样本点

在这里插入图片描述

优点

  • 容易理解,聚类效果不错,虽然是局部最优, 但往往局部最优就够了;
  • 处理大数据集的时候,该算法可以保证较好的伸缩性;
  • 当簇近似高斯分布的时候,效果非常不错;
  • 算法复杂度低。

缺点

  • K 值需要人为设定,不同 K 值得到的结果不一样;
  • 对初始的簇中心敏感,不同选取方式会得到不同结果;
  • 对异常值敏感;
  • 样本只能归为一类,不适合多分类任务;
  • 不适合太离散的分类、样本类别不平衡的分类、非凸形状的分类。

k-means++

k-means++作为一种k-means的改进,其主要工作为:优化了初始类中心的选取,使得初始类中心的距离相对较远,一介绍优化过程中的迭代次数

K-means++ 能显著的改善分类结果的最终误差,尽管计算初始点时花费了额外的时间,但是在迭代过程中,k-mean 本身能快速收敛,因此算法实际上降低了计算时间。

代码与k-means相同,只是初始类中心的选取进行了改进

初始类中心选取

# k-means++类中心的选择
def select_centre(data, k):
    center_list = []
    ran = np.random.randint(len(data))
    center_list.append(data[ran])

    for i in range(k-1):
        # 构建数据点与中心点的距离
        dis_list = []
        for c in center_list:
            tmp_list = []
            for d in data:
                dis = np.sqrt(np.sum(np.square(np.array(d) - np.array(c))))
                tmp_list.append(dis)
            dis_list.append(tmp_list)
        # 计算每个中心点与其他点的距离和
        sum_dis_list = [0 for j in range(len(data))]
        for d in dis_list:
            sum_dis_list = np.array(sum_dis_list) + np.array(d)
        # 取距离和最大的点作为新的中心点
        center_list.append(data[list(sum_dis_list).index(max(sum_dis_list))])

    return center_list

运行结果

每一轮迭代的结果

在这里插入图片描述

kernel k-means

k-means聚类算法所解决的问题为线性可分的问题,那么对于下图所示数据该如何进行聚类呢?

在这里插入图片描述

如果使用k-means进行聚类,势必会得到这样的结果:

在这里插入图片描述

可以看出与想要的结果有很大的差距。

原理

kernel k-means的原理为将二维数据转换为三维数据,如图所示:

在这里插入图片描述

坐标变换方式

本文做的变换方式为使用 2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 2*d[0]^2+2*d[1]^2 2d[0]2+2d[1]2作为第三个维度,实际中可以根据自己需求进行设计不同的维度变换法则。本文中:

(x,y)转换为(x, 2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 2*d[0]^2+2*d[1]^2 2d[0]2+2d[1]2,y)(x,y , 2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 2*d[0]^2+2*d[1]^2 2d[0]2+2d[1]2)( 2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 2*d[0]^2+2*d[1]^2 2d[0]2+2d[1]2,x, y)

为什么选择 2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 2*d[0]^2+2*d[1]^2 2d[0]2+2d[1]2作为第三个维度?

在我们选择第三个维度时,尽量选取区分度较大的一种变换作为第三个维度,这样可以防止出现如下图所示情况:

当我们选择(x,y , 2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 \sqrt{2*d[0]^2+2*d[1]^2} 2d[0]2+2d[1]2 )变幻时:

在这里插入图片描述

很显然,上图并没有很好的区分不同类簇,达到聚类的效果。

三种维度变换下的结果分别为:

2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 2*d[0]^2+2*d[1]^2 2d[0]2+2d[1]2,x, y)

在这里插入图片描述

(x, 2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 2*d[0]^2+2*d[1]^2 2d[0]2+2d[1]2,y)

在这里插入图片描述

(x,y , 2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 2*d[0]^2+2*d[1]^2 2d[0]2+2d[1]2

在这里插入图片描述

可以看到,原本在二维空间中非线性可分的数据集在转换为三维后边的线性可分(可以被一个平面分为两个部分)
同时我们也可以看到,在上面的图中,虽然数据点被明显的分为了两个部分,但是观测坐标轴会发现,在坐标轴刻度上用于区分两个类块距离的那一条轴距离很近,在进行聚类任务时就会出现如下情况

在这里插入图片描述

为防止这种情况的发生,我们要尽可能的取扩大新加入维度对聚类结果的影响,所以采用

(x,y)转换为(x, 2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 2*d[0]^2+2*d[1]^2 2d[0]2+2d[1]2,y)(x,y , 2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 2*d[0]^2+2*d[1]^2 2d[0]2+2d[1]2)( 2 ∗ d [ 0 ] 2 + 2 ∗ d [ 1 ] 2 2*d[0]^2+2*d[1]^2 2d[0]2+2d[1]2,x, y)

作为维度的变换

实验结果

在这里插入图片描述

代码

数据集生成:

# 生成数据集
def make_data():
    return datasets.make_circles(n_samples=200, factor=.5, noise=.05)

数据集整理及原始图像绘制:

# 数据集处理
def data_processing(data):
    data_ = list(data[0])
    data_list = []
    for d in data_:
        data_list.append(list(d))
    result = list(data[1])

    return data_list, result


# 数据集展示
def show_data(data, result):
    for d in data:
        if result[data.index(d)] == 0:
            plt.scatter(d[0], d[1], c="r")
        if result[data.index(d)] == 1:
            plt.scatter(d[0], d[1], c="g")
    plt.show()

数据向高维映射,及3D图像的绘制:

# 将原始二维数据点映射到三维
# 引入一个新的维度,数据点与远点的距离,转换:(x,y)转换为(x,sqrt(x**2+y**2),y)
def axes3d(data):
    new_data_list = []
    for d in data:
        new_data = [np.sqrt(d[0]**2+d[1]**2), d[0], d[1]]
        new_data_list.append(new_data)

    return new_data_list


# 三维图像绘制
def show_data_3d(data, result):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for d in data:
        if result[data.index(d)] == 0:
            ax.scatter(d[0], d[1], d[2], c="r", marker="8")
        if result[data.index(d)] == 1:
            ax.scatter(d[0], d[1], d[2], c="g", marker="8")
    plt.show()
posted @ 2021-05-19 17:34  博0_oer~  阅读(179)  评论(0编辑  收藏  举报