基于K-means聚类的图像分割
K-means算法用于聚类分析,广泛用于机器学习领域。
下面借用百度百科的解释,个人觉得讲的还算清楚:
k-means 算法接受参数 k ;然后将事先输入的n个数据对象划分为 k个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小。聚类相似度是利用各聚类中对象的均值所获得一个“中心对象”(引力中心)来进行计算的。
K-means算法是最为经典的基于划分的聚类方法,是十大经典数据挖掘算法之一。K-means算法的基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。
假设要把样本集分为c个类别,算法描述如下:
(1)适当选择c个类的初始中心;
(2)在第k次迭代中,对任意一个样本,求其到c各中心的距离,将该样本归到距离最短的中心所在的类;
(3)利用均值等方法更新该类的中心值;
(4)对于所有的c个聚类中心,如果利用(2)(3)的迭代法更新后,值保持不变,则迭代结束,否则继续迭代。
该算法的最大优势在于简洁和快速。算法的关键在于初始中心的选择和距离公式。
算法流程
首先从n个数据对象任意选择 k 个对象作为初始聚类中心;而对于所剩下其它对象,则根据它们与这些聚类中心的相似度(距离),分别将它们分配给与其最相似的(聚类中心所代表的)聚类;然后再计算每个所获新聚类的聚类中心(该聚类中所有对象的均值);不断重复这一过程直到标准测度函数开始收敛为止。一般都采用均方差作为标准测度函数. k个聚类具有以下特点:各聚类本身尽可能的紧凑,而各聚类之间尽可能的分开。
具体流程
输入:k, data[n];
(1) 选择k个初始中心点,例如c[0]=data[0],…c[k-1]=data[k-1];
(2) 对于data[0]….data[n], 分别与c[0]…c[k-1]比较,假定与c[i]差值最少,就标记为i;
(3) 对于所有标记为i点,重新计算c[i]={ 所有标记为i的data[j]之和}/标记为i的个数;
(4) 重复(2)(3),直到所有c[i]值的变化小于给定阈值。
如何处理空的聚类
果所有的点在指派步骤都未分配到某个簇,就会得到空簇。如果这种情况发生,则需要某种策略来选择一个替补质心,否则的话,平方误差将会偏大。一种方法是选择一个距离当前任何质心最远的点。这将消除当前对总平方误差影响最大的点。另一种方法是从具有最大误差平方和的簇中选择一个替补的质心。这将分裂簇并降低聚类的总误差平方和。如果有多个空簇,则该过程重复多次。
用于图像分割
输入参数为像素的横、纵坐标以及R、G、B三色值共5个参数,参数归一化后计算各个像素点与各个初始聚类像素点(假设有三个)之间五维向量的欧式距离;然后将其归类到与其距离最小的那个初始聚类中,一遍处理完之后所有点都被归类到三个集合里,如果本次归类完成后与上次相比归类情况发生变化的情况小于预设值,算法就收敛了,否则重新计算三个集合中各个集合的重心,三个中心作为聚类像素点,进行下一次归类。
借用孙兴华老师的代码展示下,代码如下:
double Distance_of_location_and_color(CImagePoint pixel_1, RGB_TRIPLE color_1, CImagePoint pixel_2, RGB_TRIPLE color_2, long image_height, long image_width) { double first_1 = double(pixel_1.m_row) / double(image_height); double second_1 = double(pixel_1.m_column) / double(image_width); double third_1 = double(color_1.m_Red) / 255.0; double fourth_1 = double(color_1.m_Green) / 255.0; double fifth_1 = double(color_1.m_Blue) / 255.0; double first_2 = double(pixel_2.m_row) / double(image_height); double second_2 = double(pixel_2.m_column) / double(image_width); double third_2 = double(color_2.m_Red) / 255.0; double fourth_2 = double(color_2.m_Green) / 255.0; double fifth_2 = double(color_2.m_Blue) / 255.0; return sqrt((first_1 - first_2) * (first_1 - first_2) + (second_1 - second_2) * (second_1 - second_2) + (third_1 - third_2) * (third_1 - third_2) + (fourth_1 - fourth_2) * (fourth_1 - fourth_2) + (fifth_1 - fifth_2) * (fifth_1 - fifth_2)); } int Update_index_into_clusters(CImagePoint current_pixel, RGB_TRIPLE current_color, CTArray< CImagePoint> centers_of_pixel, CTArray< RGB_TRIPLE > centers_of_color, long image_height, long image_width) { long dimension = centers_of_pixel.GetDimension(); CTArray< double > array_of_distances(dimension); double current_distance; int current_index; for (int index = 0; index < dimension; index++) { array_of_distances[index] = Distance_of_location_and_color(current_pixel, current_color, centers_of_pixel[index], centers_of_color[index], image_height, image_width); if (index == 0) { current_distance = array_of_distances[index]; current_index = index; } else { if (array_of_distances[index] < current_distance) { current_distance = array_of_distances[index]; current_index = index; } } } return current_index; } void Update_centers_of_clusters(const CTMatrix< RGB_TRIPLE >& color_image, const CTMatrix< int >& cluster_result, int number_of_clusters, CTArray< CImagePoint >& centers_of_pixel, CTArray< RGB_TRIPLE >& centers_of_color) { long image_height = color_image.Get_height(); long image_width = color_image.Get_width(); CTArray< int > scale_of_clusters(number_of_clusters); CTMatrix< double > sum_of_clusters(number_of_clusters, 5); for (int index = 0; index < number_of_clusters; index++) { scale_of_clusters[index] = 0; for (int sub_index = 0; sub_index < 5; sub_index++) { sum_of_clusters[index][sub_index] = 0; } } for (int row = 0; row < image_height; row++) for (int column = 0; column < image_width; column++) { int current_index = cluster_result[row][column]; scale_of_clusters[current_index] ++; sum_of_clusters[current_index][0] += row; sum_of_clusters[current_index][1] += column; sum_of_clusters[current_index][2] += color_image[row][column].m_Red; sum_of_clusters[current_index][3] += color_image[row][column].m_Green; sum_of_clusters[current_index][4] += color_image[row][column].m_Blue; } for (int index = 0; index < number_of_clusters; index++) if (scale_of_clusters[index] != 0) { centers_of_pixel[index].m_row = long(sum_of_clusters[index][0] / scale_of_clusters[index]); centers_of_pixel[index].m_column = long(sum_of_clusters[index][1] / scale_of_clusters[index]); centers_of_color[index].m_Red = BYTE(sum_of_clusters[index][2] / scale_of_clusters[index]); centers_of_color[index].m_Green = BYTE(sum_of_clusters[index][3] / scale_of_clusters[index]); centers_of_color[index].m_Blue = BYTE(sum_of_clusters[index][4] / scale_of_clusters[index]); } return; } bool Are_centers_same(const CTArray< CImagePoint>& centers_of_pixel, const CTArray< RGB_TRIPLE >& centers_of_color, const CTArray< CImagePoint>& copies_of_pixel, const CTArray< RGB_TRIPLE >& copies_of_color) { bool is_same = true; long dimension = centers_of_pixel.GetDimension(); for (int index = 0; index < dimension; index++) { if (centers_of_pixel[index] != copies_of_pixel[index] || centers_of_color[index].m_Red != copies_of_color[index].m_Red || centers_of_color[index].m_Green != copies_of_color[index].m_Green || centers_of_color[index].m_Blue != copies_of_color[index].m_Blue) { is_same = false; break; } } return is_same; } // [ ********** ] ........................................................ // [ K 均值聚类 ] ........................................................ // [ ********** ] ........................................................ CTMatrix< int > CImageColorProcess::K_means_clustering(const CTMatrix< RGB_TRIPLE >& color_image, int number_of_clusters) { long image_height = color_image.Get_height(); long image_width = color_image.Get_width(); CTMatrix< int > cluster_result(image_height, image_width); CTArray< CImagePoint> centers_of_pixel(number_of_clusters); CTArray< RGB_TRIPLE > centers_of_color(number_of_clusters); for (int index = 0; index < number_of_clusters; index++) { int row = image_height * (index + 1) / (number_of_clusters + 1); int column = image_width * (index + 1) / (number_of_clusters + 1); centers_of_pixel[index] = CImagePoint(row, column); centers_of_color[index] = color_image[row][column]; } CTArray< CImagePoint> copies_of_pixel; CTArray< RGB_TRIPLE > copies_of_color; long iteration = 0; do { iteration++; copies_of_pixel = centers_of_pixel; copies_of_color = centers_of_color; for (int row = 0; row < image_height; row++) for (int column = 0; column < image_width; column++) { cluster_result[row][column] = Update_index_into_clusters(CImagePoint(row, column), color_image[row][column], centers_of_pixel, centers_of_color, image_height, image_width); } Update_centers_of_clusters(color_image, cluster_result, number_of_clusters, centers_of_pixel, centers_of_color); } while (!Are_centers_same(centers_of_pixel, centers_of_color, copies_of_pixel, copies_of_color) && iteration < 1000); return cluster_result; }
调用方法
long number_of_clusters = 4; CTArray< RGB_TRIPLE > template_of_display(number_of_clusters); template_of_display[0] = RGB_TRIPLE(255, 255, 0); template_of_display[1] = RGB_TRIPLE(0, 175, 175); template_of_display[2] = RGB_TRIPLE(100, 0, 100); template_of_display[3] = RGB_TRIPLE(0, 0, 0); int nW(0), nH(0), nCount(0); m_SrcView.GetImgInfo(0, nW, nH, nCount); //申请原始的图像数据内存 int nSrcLine = LINEWIDTH(nW*nCount); LPBYTE lpSrcBit = new BYTE[nSrcLine*nH]; m_SrcView.GetImgData(0, nW, nH, lpSrcBit); //申请目标图像数据内存 int nDstCount = 24; int nDstLine = LINEWIDTH(nW*nCount); int nLineBytes = LINEWIDTH(nW*nCount); LPBYTE lpDstBit = new BYTE[nLineBytes*nH]; CTMatrix<RGB_TRIPLE> myimage(nH,nW); for (int i = 0; i < nH; i++) { for (int j = 0; j < nW; j++) { myimage[i][j].m_Red = lpSrcBit[i*nSrcLine + 3 * j]; myimage[i][j].m_Green = lpSrcBit[i*nSrcLine + 3 * j+1]; myimage[i][j].m_Blue = lpSrcBit[i*nSrcLine + 3 * j+2]; } } //调用处理函数 CTMatrix< int > cluster_results = CImageColorProcess::K_means_clustering(myimage, number_of_clusters); int image_height = nH; int image_width = nW; CTMatrix< RGB_TRIPLE > display_image(image_height, image_width); for (int row = 0; row < image_height; row++) { for (int column = 0; column < image_width; column++) { display_image[row][column] = template_of_display[cluster_results[row][column]]; lpDstBit[row*nDstLine + column * 3] = display_image[row][column].m_Red; lpDstBit[row*nDstLine + column * 3+1] = display_image[row][column].m_Green; lpDstBit[row*nDstLine + column * 3+2] = display_image[row][column].m_Blue; } } //显示图像数据 m_DstView.SetImgData(lpDstBit, nW, nH, nDstCount); m_DstView.Invalidate(); delete[] lpSrcBit; delete[] lpDstBit;
效果图发几张
原图 三个聚类点 四个聚类点
版权声明: