Mean Shift 聚类的 C++ 实现
参见：http://blog.csdn.net/u014568921/article/details/45197027
// meanshift-cluster.cpp : 定义控制台应用程序的入口点。 // #include "stdafx.h" #include<iostream> #include<vector> #include<assert.h> #include<cstdlib> #include<time.h> using namespace std; #define MSTYPE double class meanshift { private: struct MSData { vector<MSTYPE>data; //unsigned int dim; MSData(unsigned int d) { //dim = d; data.resize(d); } }; vector<MSData>dataset; double kernel_bandwidth; MSData shiftvec(MSData vec) { MSData shiftvector(vec.data.size()); double total_weight = 0; for (int i = 0; i<dataset.size(); i++){ MSData temp = dataset[i]; double distance = euclidean_distance(vec, temp); double weight = gaussian_kernel(distance); for (int j = 0; j<shiftvector.data.size(); j++){ shiftvector.data[j] += temp.data[j] * weight; } total_weight += weight; } for (int i = 0; i<shiftvector.data.size(); i++){ shiftvector.data[i] /= total_weight; } return shiftvector; } double gaussian_kernel(double distance){ double temp = exp(-(distance*distance) / (kernel_bandwidth)); return temp; } double euclidean_distance(const MSData &data1, const MSData &data2) { assert(data1.data.size() == data2.data.size()); double sum = 0; for (int i = 0; i<data1.data.size(); i++){ sum += (data1.data[i] - data2.data[i]) * (data1.data[i] - data2.data[i]); } return sqrt(sum); } public: meanshift(double kernel_bandwidth) :kernel_bandwidth(kernel_bandwidth) { time_t t; srand(time(&t)); } vector<MSData> apply() { vector<int> stop_moving; stop_moving.resize(dataset.size()); vector<MSData> shifted_points = dataset; double max_shift_distance; do { max_shift_distance = 0; for (int i = 0; i<shifted_points.size(); i++){ if (!stop_moving[i]) { MSData point_new = shiftvec(shifted_points[i]); double shift_distance = euclidean_distance(point_new, shifted_points[i]); if (shift_distance > max_shift_distance){ max_shift_distance = shift_distance; } #define EPSILON 0.00000001 if (shift_distance <= EPSILON) { stop_moving[i] = 1; } shifted_points[i] = point_new; } } printf("max_shift_distance: %f\n", max_shift_distance); 
} while (max_shift_distance > EPSILON); for (int i = 0; i < dataset.size(); i++) { cout << "原始坐标 (" << dataset[i].data[0] << "," << dataset[i].data[1] << ") 滑动到 (" << shifted_points[i].data[0] << "," << shifted_points[i].data[1] << ")" << endl; } return shifted_points; } void generatedata(int datanums,vector<int>&span) { for (int i = 0; i < datanums; i++) { MSData dd(span.size()); for (int j = 0; j < span.size(); j++) { dd.data[j] = double(rand()) / (RAND_MAX + 1.0)*span[j]; } dataset.push_back(dd); } } }; int _tmain(int argc, _TCHAR* argv[]) { meanshift ms(4); vector<int>span; span.push_back(20); span.push_back(20); ms.generatedata(100, span); ms.apply(); return 0; }
运行结果如下图所示（原文中的结果截图未包含在此处）。
版权声明: