k_means算法的C++实现
首先画出k_means算法的流程图:
具体代码,C++实现:
// k-means clustering of 2-D points.
//
// Reads whitespace-separated (x1, x2) records from a text file, asks the user
// for the number of clusters k, iterates the k-means assignment/update steps
// and writes the resulting centers and cluster memberships to a result file.

#include <cmath>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <vector>

using namespace std;

// Maximum number of records that can be held in node[].
const int N = 10000;

// Iteration continues while at least one center moves by a squared distance
// of this much or more between two consecutive iterations.
const double CENTER_SHIFT_THRESHOLD = 5.0;

// A 2-D point.
// - In node[]           : flag is the index of the cluster the point belongs to.
// - In the accumulators : flag is the number of points summed so far.
struct oneNode {
    double x1;
    double x2;
    int flag;
};

oneNode node[N];  // all loaded records
int k;            // number of clusters requested by the user
int n;            // number of records actually loaded into node[]

void DataInput();
void SetInitial(oneNode *p,oneNode *q);
void SetQzero(oneNode &qi);
bool GetExist(int xm,int *sb,int i);
void DataIteration(oneNode *p,oneNode *q,double &E1,double &E3);
double GetRange(oneNode &node1,oneNode &node2);
bool JudgeCondition(oneNode *p,oneNode *q);
void DataOutput(oneNode *p,double E1,double E3);

// Program entry point: load data, read k, cluster, write the result.
// Returns EXIT_FAILURE when there is no data or k is not in [1, n].
int main() {
    double E1 = 0;  // total error of the previous accepted iteration
    double E3 = 0;  // total error of the current iteration
    n = 0;
    DataInput();    // read the records of the text file into node[]
    if (n == 0) {
        // Without records rand() % n in SetInitial would divide by zero.
        cout << "文件中没有可用的记录" << endl;
        return EXIT_FAILURE;
    }

    cout << "请输入新的聚类数目:";
    // k must be positive (array size) and at most n, otherwise SetInitial
    // could never find k distinct records and would loop forever.
    if (!(cin >> k) || k <= 0 || k > n) {
        cout << "聚类数目无效" << endl;
        return EXIT_FAILURE;
    }

    // Vectors own the storage; the helpers keep their pointer-based interface.
    vector<oneNode> centers(k);
    vector<oneNode> sums(k);
    oneNode *p = &centers[0];
    oneNode *q = &sums[0];

    SetInitial(p, q);             // choose the k initial cluster centers
    DataIteration(p, q, E1, E3);  // iterate until the centers settle
    DataOutput(p, E1, E3);        // report the result
    return 0;
}

// Loads (x1, x2) records from the input file into node[] and counts them in n.
// Terminates the program with EXIT_FAILURE if the file cannot be opened.
void DataInput() {
    ifstream ifs("D:\\k_means\\aa11.txt");
    if (!ifs.is_open()) {
        cout << "文件读取失败" << endl;
        exit(EXIT_FAILURE);
    }

    // Test the extraction itself rather than eof(): this drops a trailing
    // incomplete record instead of storing garbage, and the n < N bound keeps
    // the loop from writing past the end of node[].
    double x1 = 0;
    double x2 = 0;
    while (n < N && ifs >> x1 >> x2) {
        node[n].x1 = x1;
        node[n].x2 = x2;
        ++n;
    }

    // Tell the user when the file holds more records than node[] can store.
    if (n == N && ifs >> x1) {
        cout << "记录数超过上限,多余的记录已被忽略" << endl;
    }
}

// Picks k distinct random records as the initial centers p[0..k-1] and clears
// the accumulators q[0..k-1].
// Precondition: 1 <= k <= n (checked by main); otherwise no k distinct
// indices exist.
void SetInitial(oneNode *p,oneNode *q) {
    vector<int> chosen(k);  // indices of the records already used as centers
    srand((unsigned)time(0));
    for (int i = 0; i < k; i++) {
        int xm;
        do {
            xm = rand() % n;
        } while (GetExist(xm, &chosen[0], i));
        chosen[i] = xm;
        p[i] = node[xm];
        SetQzero(q[i]);
    }
}

// Resets an accumulator: coordinate sums and point count become zero.
void SetQzero(oneNode &qi) {
    qi.x1 = 0;
    qi.x2 = 0;
    qi.flag = 0;
}

// Returns true if xm occurs among the first i entries of sb.
bool GetExist(int xm,int *sb,int i) {
    for (int j = 0; j < i; j++) {
        if (sb[j] == xm) {
            return true;
        }
    }
    return false;
}

// Runs the k-means iterations.
//   p  : in/out, current cluster centers (k entries)
//   q  : scratch accumulators (k entries), must be zeroed on entry
//   E1 : out, total squared error of the last accepted iteration
//   E3 : out, total squared error of the last evaluated iteration
// The loop ends when no center moves by CENTER_SHIFT_THRESHOLD or more, or
// when the total error stops decreasing (E1 <= E3).
void DataIteration(oneNode *p,oneNode *q,double &E1,double &E3) {
    int iteration = 0;
    bool moved = false;
    do {
        // Assignment step: attach every record to its nearest center and
        // accumulate per-cluster coordinate sums and counts in q.
        E3 = 0;
        for (int i = 0; i < n; i++) {
            int best = 0;
            double bestDist = GetRange(node[i], p[0]);
            for (int j = 1; j < k; j++) {
                const double dist = GetRange(node[i], p[j]);
                if (dist < bestDist) {
                    bestDist = dist;
                    best = j;
                }
            }
            node[i].flag = best;
            E3 += bestDist;
            q[best].x1 += node[i].x1;
            q[best].x2 += node[i].x2;
            q[best].flag += 1;
        }

        // From the second iteration on, stop as soon as the error no longer
        // decreases; p still holds the centers this error was computed with.
        if (iteration > 0 && E1 <= E3) {
            break;
        }
        E1 = E3;
        iteration++;

        // Update step: turn the sums into means.
        for (int i = 0; i < k; i++) {
            if (q[i].flag != 0) {
                q[i].x1 = q[i].x1 / q[i].flag;
                q[i].x2 = q[i].x2 / q[i].flag;
            } else {
                // An empty cluster keeps its previous center instead of
                // being moved to the origin (0, 0).
                q[i].x1 = p[i].x1;
                q[i].x2 = p[i].x2;
            }
        }

        moved = JudgeCondition(p, q);
        for (int i = 0; i < k; i++) {
            p[i] = q[i];
            SetQzero(q[i]);
        }
    } while (moved);
}

// Returns the squared Euclidean distance between two points.
double GetRange(oneNode &node1,oneNode &node2) {
    const double dx = node1.x1 - node2.x1;
    const double dy = node1.x2 - node2.x2;
    return dx * dx + dy * dy;
}

// Returns true if any center moved by a squared distance of at least
// CENTER_SHIFT_THRESHOLD between the old centers p and the new centers q.
bool JudgeCondition(oneNode *p,oneNode *q) {
    for (int i = 0; i < k; i++) {
        if (GetRange(p[i], q[i]) >= CENTER_SHIFT_THRESHOLD) {
            return true;
        }
    }
    return false;
}

// Prints the total error and writes every center followed by its member
// points to the result file. Reports failure when the last evaluated error E3
// differs from the accepted error E1 (the error increased in the final
// iteration).
void DataOutput(oneNode *p,double E1,double E3) {
    // Exact comparison is intended: on success E1 was assigned from E3.
    if (E1 != E3) {
        cout << "操作失败" << endl;
        return;
    }

    cout << "本次总体误差是:" << E1 << endl;
    ofstream ofile("D://k_means/aa11result.txt");
    if (!ofile.is_open()) {
        cout << "结果文件写入失败" << endl;
        return;
    }

    for (int i = 0; i < k; i++) {
        ofile << "第" << i + 1 << "个聚类中心是:" << "(" << p[i].x1 << "," << p[i].x2 << ")" << '\n'
              << "属于第" << i + 1 << "个聚类的所有点是:" << '\n';
        for (int j = 0; j < n; j++) {
            if (node[j].flag == i) {
                ofile << node[j].x1 << " " << node[j].x2 << '\n';
            }
        }
    }
}