k-means算法C++实现

k均值(k-means)算法是最常见的聚类算法之一,实现简单,效果也比较好。k-means算法把n个对象划分成指定的k个簇,每个簇中所有对象的均值即为该簇的聚点(中心)。

k均值算法有如下五个步骤:

  1. 随机生成初始的k个簇心。可以从样本中随机选择,也可以根据样本中每个特征的取值范围(最小值到最大值)随机生成。
  2. 对每个样本计算其到每个簇心的欧氏距离,将样本划分到欧氏距离最小的簇心(聚点)。
  3. 对划分到同一个簇心(聚点)的样本计算平均值,用该均值更新簇心(聚点)。
  4. 若有样本的簇划分(从而某些簇心)发生变化,转到步骤2;若所有样本的划分都没有变化,转到步骤5。
  5. 输出划分结果
#include <cassert>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <exception>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using namespace std;
 11 namespace terse {
 12 class Kmeans {
 13 private:
 14     vector<vector<double>> m_dataSet;
 15     int m_k;
 16     vector<int> m_clusterResult;         // result of cluster
 17     vector<vector<double>> m_cluserCent; //center of k clusters
 18 
 19 private:
 20     vector<string> split(const string& s, string pattern) {
 21         vector<string> res;
 22         size_t start = 0;
 23         size_t end = 0;
 24         while (start < s.size()) {
 25             end = s.find_first_of(pattern, start);
 26             if (end == string::npos) {
 27                 res.push_back(s.substr(start, end - start - 1));
 28                 return res;
 29             }
 30             res.push_back(s.substr(start, end - start));
 31             start = end + 1;
 32         }
 33         return res;
 34     }
 35 
 36     void loadDataSet(const char* fileName) {
 37         ifstream dataFile(fileName);
 38         if (!dataFile.is_open()) {
 39             cerr << "open file " << fileName << "failed!\n";
 40             return;
 41         }
 42         string tmpstr;
 43         vector<double> data;
 44         while (!dataFile.eof()) {
 45             data.clear();
 46             tmpstr.clear();
 47             getline(dataFile, tmpstr);
 48             vector<string> tmp = split(tmpstr, ",");
 49             for (string str : tmp) {
 50                 data.push_back(stod(str));
 51             }
 52             this->m_dataSet.push_back(data);
 53         }
 54         dataFile.close();
 55     }
 56 
 57     //compute Euclidean distance of two vector
 58     double distEclud(vector<double>& v1, vector<double>& v2) {
 59         assert(v1.size() == v2.size());
 60         double dist = 0;
 61         for (size_t i = 0; i < v1.size(); i++) {
 62             dist += (v1[i] - v2[i]) * (v1[i] - v2[i]);
 63         }
 64         return sqrt(dist);
 65     }
 66 
 67     void generateRandCent() {
 68         int numOfFeats = this->m_dataSet[0].size();
 69         size_t numOfSamples = this->m_dataSet.size();
 70 
 71         //first:min second:max
 72         vector<pair<double, double>> minMaxOfFeat(numOfFeats);
 73         for (int i = 0; i < numOfFeats; i++) {
 74             minMaxOfFeat[i].first = this->m_dataSet[0][i];
 75             minMaxOfFeat[i].second = this->m_dataSet[0][i];
 76         }
 77         for (size_t i = 1; i < numOfSamples; i++) {
 78             for (int j = 0; j < numOfFeats; j++) {
 79                 if (this->m_dataSet[i][j] > minMaxOfFeat[j].second) {
 80                     minMaxOfFeat[j].second = this->m_dataSet[i][j];
 81                 }
 82                 if (this->m_dataSet[i][j] < minMaxOfFeat[j].first) {
 83                     minMaxOfFeat[j].first = this->m_dataSet[i][j];
 84                 }
 85             }
 86         }
 87         srand(time(NULL));
 88         for (int i = 0; i < this->m_k; i++) {
 89             for (int j = 0; j < numOfFeats; j++) {
 90                 this->m_cluserCent[i][j] = minMaxOfFeat[j].first
 91                         + (minMaxOfFeat[j].second - minMaxOfFeat[j].first)
 92                                 * (rand() / (double) RAND_MAX);
 93             }
 94         }
 95 
 96     }
 97 
 98     void printClusterCent(int iter) {
 99         int m = this->m_cluserCent.size();
100         int n = this->m_cluserCent[0].size();
101         cout << "iter =  " << iter;
102         for (int i = 0; i < m; i++) {
103             cout << " {";
104             for (int j = 0; j < n; j++) {
105                 cout << this->m_cluserCent[i][j] << ",";
106             }
107             cout << "};";
108         }
109         cout << endl;
110     }
111 
112     void writeResult(const char* fileName = "res.txt") {
113         ofstream fout(fileName);
114         if (!fout.is_open()) {
115             cerr << "open file " << fileName << "failed!";
116             return;
117         }
118         for (size_t i = 0; i < this->m_dataSet.size(); i++) {
119             for (size_t j = 0; j < this->m_dataSet[0].size(); j++) {
120                 fout << this->m_dataSet[i][j] << "\t";
121             }
122             fout << setprecision(5) << this->m_clusterResult[i] << "\n";
123         }
124         fout.close();
125     }
126 
127 public:
128     Kmeans(int k, const char* fileName) {
129         this->m_k = k;
130         this->loadDataSet(fileName);
131         this->m_clusterResult.reserve(this->m_dataSet.size());
132         this->m_cluserCent = vector<vector<double>>(k,
133                 vector<double>(this->m_dataSet[0].size()));
134         generateRandCent();
135     }
136 
137     Kmeans(int k, vector<vector<double>>& data) {
138         this->m_k = k;
139         this->m_dataSet = data;
140         this->m_clusterResult.reserve(this->m_dataSet.size());
141         this->m_cluserCent = vector<vector<double>>(k,
142                 vector<double>(this->m_dataSet[0].size()));
143         generateRandCent();
144     }
145 
146     //verbose = 1,printClusterCent();
147     void kmeansCluster(int verbose = 1) {
148         int iter = 0;
149         bool isClusterChanged = true;
150         while (isClusterChanged) {
151             isClusterChanged = false;
152             //step 1: find the nearest centroid of each point
153             int numOfFeats = this->m_dataSet[0].size();
154             size_t numOfSamples = this->m_dataSet.size();
155             for (size_t i = 0; i < numOfSamples; i++) {
156                 int minIndex = -1;
157                 double minDist = INT_MAX;
158                 for (int j = 0; j < this->m_k; j++) {
159                     double dist = distEclud(this->m_cluserCent[j],
160                             m_dataSet[i]);
161                     if (dist < minDist) {
162                         minDist = dist;
163                         minIndex = j;
164                     }
165                 }
166                 if (m_clusterResult[i] != minIndex) {
167                     isClusterChanged = true;
168                     m_clusterResult[i] = minIndex;
169                 }
170             }
171 
172             //step 2: update cluster center
173             vector<size_t> cnt(this->m_k, 0);
174             this->m_cluserCent = vector<vector<double>>(this->m_k,
175                     vector<double>(numOfFeats, 0.0));
176             for (size_t i = 0; i < numOfSamples; i++) {
177                 for (int j = 0; j < numOfFeats; j++) {
178                     this->m_cluserCent[this->m_clusterResult[i]][j] +=
179                             this->m_dataSet[i][j];
180                 }
181                 cnt[this->m_clusterResult[i]]++;
182             }
183             // mean of the vector belong to a cluster
184             for (int i = 0; i < this->m_k; i++) {
185                 for (int j = 0; j < numOfFeats; j++) {
186                     this->m_cluserCent[i][j] /= cnt[i];
187                 }
188             }
189             if (verbose)
190                 printClusterCent(iter++);
191         }
192         writeResult();
193     }
194 };
195 
196 };
197 
198 int main(){
199     terse::Kmeans kmeans(4,"datafile.txt");
200     kmeans.kmeansCluster();
201     return 0;
202 }
203 /*namespace terse*/

 

posted @ 2017-04-23 22:39  wxquare  阅读(512)  评论(0编辑  收藏  举报