K-均值
K均值聚类
我们现在考虑这个问题:寻找多维空间中数据点的分组或者聚类问题。假设有一个数据集 ,它是D维欧几里得空间中的随机变量 的N次观测组成的。我们的目标是将数据集划分为K个类别。先假定K的值是给定的。
直观上讲,我们认为一组数据点中的一个聚类中,内部点之间的距离应该小于该聚类中数据点与外部点之间的距离。不妨设这个聚类为区域中有M个数据点 ,其中 是数字 的一个排列。那么直白地翻译上述为 可以看出这种形式化的表示比较繁琐。
现在引入一组D维向量 是K个聚类的代表,认为是对应聚类的中心。那么我们的目标就是找到一组向量 及数据点所属的聚类,使得每个数据点和它最近的向量 之间的距离是最小的,即 是所有选择中最小的。
为了方便构建目标函数,引入了二值指示变量 对应同一个数据点,这n个值只有一个等于1,其余都等于0。如果 即数据点属于聚类k,这种表示方式称之为"1 of K"。为此我们可以定义一个目标函数:
这个公式涉及到两种量,聚类中心 和 分配, 分配会随着聚类中心发生变化。可以使用一种迭代的方法求解使得目标函数最小。
第一阶段是固定 ,关于 最小化;第二阶段是固定 关于 最小化。
第一阶段是固定 ,关于 最小化。由于不同n相关的项是相互独立的,因此可以对于每个n分别进行最优化,只要k的值使得 最小,我们就令 ,即
第二阶段是固定 关于 最小化。目标函数是 的一个二次函数,令它关于的导数等于零,即可达到最小值,即
可以很容易求得
这个公式中分母表示聚类k中数据点的数量,分母是聚类k中数据点之和,合在一起就是聚类k中所有数据点的均值。因此求得的聚类k的聚类中心就等于属于这个聚类的数据点的均值。这就是K均值算法(K-means)的由来。为数据点分配聚类的步骤和计算聚类中心的步骤不停迭代进行,直到聚类的分配不再改变或者直到达到一定的迭代次数停止。由于每个阶段都减小了目标函数的值,因此算法的收敛性得到保证。
下面以随机数和老忠实间歇喷泉为例,来编程实现K-means代码。
示例1
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from prml.clustering import KMeans
from prml.rv import (
MultivariateGaussianMixture,
BernoulliMixture
)
# training data
x1 = np.random.normal(size=(100, 2))
x1 += np.array([-5, -5])
x2 = np.random.normal(size=(100, 2))
x2 += np.array([5, -5])
x3 = np.random.normal(size=(100, 2))
x3 += np.array([0, 5])
x_train = np.vstack((x1, x2, x3))
x0, x1 = np.meshgrid(np.linspace(-10, 10, 100), np.linspace(-10, 10, 100))
x = np.array([x0, x1]).reshape(2, -1).T
kmeans = KMeans(n_clusters=3)
kmeans.fit(x_train)
cluster = kmeans.predict(x_train)
plt.scatter(x_train[:, 0], x_train[:, 1], c=cluster)
plt.scatter(kmeans.centers[:, 0], kmeans.centers[:, 1], s=200, marker='X', lw=2, c=['purple', 'cyan', 'yellow'], edgecolor="white")
plt.contourf(x0, x1, kmeans.predict(x).reshape(100, 100), alpha=0.1)
plt.xlim(-10, 10)
plt.ylim(-10, 10)
plt.gca().set_aspect('equal', adjustable='box')
plt.show()

示例2
import pandas as pd
csv_data = pd.read_csv('data.csv')
x_train = csv_data.to_numpy()
kmeans = KMeans(n_clusters=2)
kmeans.fit(x_train)
cluster = kmeans.predict(x_train)
plt.scatter(x_train[:, 0], x_train[:, 1], c=cluster)
plt.scatter(kmeans.centers[:, 0], kmeans.centers[:, 1], s=200, marker='X', lw=2, c=['purple', 'cyan'], edgecolor="white")
plt.grid(linestyle='-.')
plt.show()

【参考】
1.Pattern Recognition and Machine Learing 中/英文版
2.https://github.com/ctgk/PRML
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异
· 三行代码完成国际化适配,妙~啊~