Kmeans
Algorithm Overview
The k-means algorithm is one of the simplest and most popular clustering algorithms in machine learning. It takes the data points and the number of clusters, k, as input.
It then randomly places k points in the data space to serve as the initial cluster centers, called centroids. After initialization, the following two steps are repeated until the set of k centroids no longer changes:
- Assignment of points to the centroids: every data point is assigned to the centroid closest to it (by Euclidean distance). The collection of data points assigned to a particular centroid is called a cluster, so assigning the points to k centroids forms k clusters.
- Update of the centroids: each centroid is recomputed as the center of its cluster, i.e. the mean of all the points in that cluster. In the next iteration, all data points are reassigned to these new centroids.
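The two steps above can be sketched without any plotting as follows. This is a minimal NumPy sketch, not the author's code: the helper names `assign`, `update` and `kmeans`, the `max_iter` cap, and the choice to leave an empty cluster's centroid where it is are assumptions for illustration. The initial centroids are passed in explicitly instead of being drawn at random, so the result is reproducible.

```python
import numpy as np


def assign(X, centroids):
    """Assignment step: index of the nearest centroid for every point."""
    # (n, 1, d) - (1, k, d) broadcasts to (n, k, d); summing over d gives the
    # squared Euclidean distance from each point to each centroid.
    d2 = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
    return d2.argmin(axis=1)


def update(X, labels, centroids):
    """Update step: move each centroid to the mean of the points assigned to it."""
    new = centroids.copy()
    for j in range(len(centroids)):
        members = X[labels == j]
        if len(members) > 0:  # assumption: an empty cluster keeps its old centroid
            new[j] = members.mean(axis=0)
    return new


def kmeans(X, init, max_iter=100):
    """Alternate the two steps until the centroids stop changing (hypothetical helper)."""
    centroids = np.asarray(init, dtype=float)
    for _ in range(max_iter):
        labels = assign(X, centroids)
        new = update(X, labels, centroids)
        if np.allclose(new, centroids):  # no further change: converged
            break
        centroids = new
    return centroids, assign(X, centroids)


# Two obvious groups of points, starting from two centroids in the same group.
X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
centroids, labels = kmeans(X, init=X[[0, 1]])
print(centroids)
print(labels)
```

Even though both starting centroids lie in the left group, the second centroid is pulled toward the right group after the first update, and the loop settles on one centroid per group. With random initialization, different starting points can converge to different local optima.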
Kmeans Demo

Code
```python
import numpy as np
import matplotlib.pyplot as plt


class Kmeans:
    """K-means implemented with Python and NumPy."""

    def __init__(self, k_):
        self.k = k_              # k is the number of clusters
        self.threshold = 1e-10   # stop once the centroids move less than this
        self.last_k_cluster = None

    def fit(self, X):
        # Convert X to a float ndarray (integer input would truncate the centroid means)
        X = np.array(X, dtype=float)
        # Set the random seed
        np.random.seed(20)
        # Pick k distinct points of X at random as the initial centroids
        self.k_cluster = X[np.random.choice(len(X), self.k, replace=False)]
        # Initialize the cluster label of every point in X
        self.labels = np.zeros(len(X), dtype=int)
        times = 0
        plt.scatter(X[:, 0], X[:, 1], c='black')
        plt.pause(1)
        while True:
            # Assign every point in X to a cluster
            for index, point in enumerate(X):
                # Squared Euclidean distance from point to each centroid;
                # NumPy broadcasting subtracts all k centroids at once
                distance = np.sum(np.power(point - self.k_cluster, 2), axis=1)
                # Label point with the index of the nearest centroid
                self.labels[index] = distance.argmin()
            # Plot: color the points by their new cluster, mark the centroids with an X
            plt.scatter(X[:, 0], X[:, 1], c=self.labels, s=50)
            plt.scatter(self.k_cluster[:, 0], self.k_cluster[:, 1], marker='X', c='black', s=100)
            plt.pause(0.5)
            # path = './Images/' + str(times) + '.jpg'
            # plt.savefig(path)
            # times += 1
            # Update each centroid to "the average of all the points in the cluster"
            self.last_k_cluster = self.k_cluster.copy()  # keep the previous centroids
            for i in range(self.k):
                members = X[self.labels == i]
                if len(members) > 0:  # an empty cluster keeps its previous centroid
                    self.k_cluster[i] = np.mean(members, axis=0)
            # Euclidean norm of the change over all centroids; if it is below
            # the threshold, the centroids have stopped moving and we are done
            dist = np.sqrt(np.sum(np.power(self.last_k_cluster - self.k_cluster, 2)))
            if dist <= self.threshold:
                break

    def predict(self, X):
        # Convert X to a float ndarray
        X = np.array(X, dtype=float)
        result = np.zeros(len(X), dtype=int)
        for index, point in enumerate(X):
            # Squared Euclidean distance to each centroid (via broadcasting)
            distance = np.sum(np.power(point - self.k_cluster, 2), axis=1)
            # Index of the nearest centroid
            result[index] = distance.argmin()
        return result


# Test code
# from KMeans_Shayue import *
if __name__ == '__main__':
    obj = Kmeans(3)
    np.random.seed(10)
    X = np.random.randint(1, 300, (100, 2))
    obj.fit(X)
    plt.show()  # keep the final plot open
```
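Why does the `while True` loop in `fit` eventually stop? Both steps can only lower (or keep) the within-cluster sum of squared distances: reassigning a point to a nearer centroid cannot increase its squared distance, and the mean of a set of points is the location that minimizes the sum of squared distances to them. Since there are only finitely many ways to partition n points, the centroids must eventually stop changing, though possibly at a local rather than a global minimum. The sketch below checks this numerically. The names `inertia` and `lloyd_history` are hypothetical, and it uses NumPy's `default_rng` instead of the global seed, so its data differ from the test code above.

```python
import numpy as np


def inertia(X, centroids, labels):
    """Within-cluster sum of squared distances to the assigned centroids."""
    return float(np.sum((X - centroids[labels]) ** 2))


def lloyd_history(X, k, iters=10, seed=0):
    """Run a fixed number of assignment/update rounds, recording the inertia
    right after each assignment step (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # k distinct data points as initial centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    history = []
    for _ in range(iters):
        # Assignment step: nearest centroid by squared Euclidean distance
        d2 = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
        labels = d2.argmin(axis=1)
        history.append(inertia(X, centroids, labels))
        # Update step: each non-empty cluster's centroid becomes its mean
        for j in range(k):
            if np.any(labels == j):
                centroids[j] = X[labels == j].mean(axis=0)
    return history


rng = np.random.default_rng(10)
X = rng.integers(1, 300, size=(100, 2)).astype(float)
history = lloyd_history(X, k=3)
print(history)
```

The printed list should be non-increasing and, once the clusters stop changing, constant; that constant tail is exactly the state the threshold test in `fit` detects.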