k-means 算法的 C++ 实现（改写为面向对象版本）
类图如下：
程序分为 kMeans 和 pointClass 两个类。kMeans 有两个成员变量：数据点总个数和最终聚类个数；pointClass 类使用结构体 point 表示数据点和聚类中心。
具体代码如下：
kMeans.h
#ifndef KMEANS_H
#define KMEANS_H

/// Driver for one k-means clustering run.
///
/// The constructor prompts for and reads the number of data points and the
/// number of clusters from standard input; k_means() then runs the
/// clustering on those sizes.
class kMeans{
protected:
    int numOfPoint;   ///< total number of data points to use
    int numOfCenter;  ///< number of clusters (k)
public:
    kMeans();
    void k_means();
};

#endif // KMEANS_H
kMeans.cpp
#include "kMeans.h"
#include "point.h"
#include <iostream>
using namespace std;

// Prompts for and reads the number of data points and clusters.
// Both members start at 0 so that a failed extraction leaves them in a
// defined state (which k_means() then rejects) instead of indeterminate.
kMeans::kMeans() : numOfPoint(0), numOfCenter(0)
{
    cout << "分别输入数据点和最终聚类的个数: ";
    cin >> numOfPoint >> numOfCenter;
}

// Runs the clustering: loads the points, picks random initial centers, then
// alternates assignment (setPoint) and center update (getNewCenter) until
// IsEnd() reports that no center moved noticeably. On success the result is
// exported; if the loop is left early through getError(), a failure message
// is printed instead.
void kMeans::k_means()
{
    // pointClass allocates arrays of these sizes and InitCenter draws
    // numOfCenter distinct indices from [0, numOfPoint). A negative size
    // makes new[] throw, numOfPoint == 0 means rand() % 0, and
    // numOfCenter > numOfPoint can never yield enough distinct indices.
    if (numOfPoint <= 0 || numOfCenter <= 0 || numOfCenter > numOfPoint){
        cerr << "输入无效: 需要满足 0 < 聚类个数 <= 数据点个数" << endl;
        return;
    }

    pointClass pc(numOfPoint, numOfCenter);
    pc.InitCenter(numOfPoint, numOfCenter);

    bool notConverged = true;
    while (notConverged){
        pc.setPoint(numOfPoint, numOfCenter);
        if (pc.getError()){
            break;  // leaves notConverged == true, reported as failure below
        }
        pc.getNewCenter(numOfPoint, numOfCenter);
        notConverged = pc.IsEnd(numOfCenter);
        pc.resetCenterOld(numOfCenter);
    }

    if (notConverged) {
        cout << "聚类操作无法完成!!" << endl;
    }else{
        pc.ExportData(numOfPoint, numOfCenter);
    }
}
point.h
#ifndef POINT_H
#define POINT_H

/// One sample (or cluster center) with four coordinates.
/// `flag` holds the index of the assigned cluster for a data point; while
/// pointClass::getNewCenter() runs, it is reused as the member counter of a
/// center.
struct point{
    double x1, x2, x3, x4;
    int flag;
};

/// Holds the data points and the old/new cluster centers and implements the
/// individual k-means steps. Callers pass the array sizes to every method.
class pointClass{
protected:
    double errorOld;         ///< reference error that getError() compares errorNew against
    double errorNew;         ///< error computed by the latest setPoint() pass
    point *pList;            ///< numOfPoint data points (owned, new[])
    point *centerListNew;    ///< numOfCenter centers being computed (owned, new[])
    point *centerListOld;    ///< numOfCenter centers of the previous pass (owned, new[])
    double GetDistance(point point1, point point2);
    bool GetExist(int xm, int centerList[], int n);
public:
    pointClass (int numOfPoint, int numOfCenter);
    void InitCenter (int numOfPoint, int numOfCenter);
    void setPoint (int numOfPoint, int numOfCenter);
    bool getError();
    void getNewCenter (int numOfPoint, int numOfCenter);
    bool IsEnd (int numOfCenter);
    void ExportData (int numOfPoint, int numOfCenter);
    void resetCenterOld (int numOfCenter);
    ~pointClass ();
private:
    // The destructor delete[]s the three arrays above; an implicit copy would
    // share them and free them twice. Declared but intentionally not defined
    // so that any copy fails to compile (or, from inside the class, to link).
    pointClass(const pointClass&);
    pointClass& operator=(const pointClass&);
};

#endif // POINT_H
point.cpp
#include <iostream> #include <fstream> #include <cmath> #include <ctime> #include "point.h" using namespace std; double pointClass::GetDistance(point point1, point point2) { return pow(pow(point1.x1 - point2.x1, 2) + pow(point1.x2 - point2.x2, 2) + pow(point1.x3 - point2.x3, 2) + pow(point1.x4 - point2.x4, 2), 0.5); } bool pointClass::GetExist(int xm, int CentIndex[], int n) { bool b = false; for (int i = 0; i < n; i++){ if (xm == CentIndex[i]){ b = true; break; } } return b; } pointClass::pointClass(int numOfPoint, int numOfCenter) { pList = new point[numOfPoint]; centerListOld = new point[numOfCenter]; centerListNew = new point[numOfCenter]; ifstream ifile("D:\\IrisData.txt"); if (!ifile.is_open()){ cerr << "file" << endl; exit(0); } int i = 0; while (i < numOfPoint){ ifile >> pList[i].x1 >> pList[i].x2 >> pList[i].x3 >> pList[i].x4; pList[i].flag = 0; i++; } ifile.close(); errorNew = 0, errorOld = 0; for (i = 0; i < numOfCenter; i++){ centerListOld[i].x1 = 0; centerListOld[i].x2 = 0; centerListOld[i].x3 = 0; centerListOld[i].x4 = 0; centerListNew[i].x1 = 0; centerListNew[i].x2 = 0; centerListNew[i].x3 = 0; centerListNew[i].x4 = 0; centerListNew[i].flag = 0; centerListOld[i].flag = 0; } } void pointClass::InitCenter(int numOfPoint, int numOfCenter) { int xm, i; int *CenterIndex = new int[numOfCenter]; srand((unsigned)time(0)); for (i = 0; i < numOfCenter; i++){ do { xm = rand() % numOfPoint; } while (GetExist(xm, CenterIndex, i)); CenterIndex[i] = xm; } for (i = 0; i < numOfCenter; i++){ centerListOld[i] = pList[CenterIndex[i]]; } } void pointClass::setPoint(int numOfPoint, int numOfCenter) { errorNew = 0; for (int i = 0; i < numOfPoint; i++){ int flagi = 0; double distance = GetDistance(pList[i], centerListOld[0]); for (int j = 1; j < numOfCenter; j++){ double tmp = GetDistance(pList[i], centerListOld[j]); if (tmp < distance){ tmp = distance; flagi = j; } } pList[i].flag = flagi; errorNew = GetDistance(pList[i], centerListOld[flagi]); } } bool 
pointClass::getError() { bool b = false; if (errorOld != 0 && errorNew >= errorOld){ b = true; } return b; } void pointClass::getNewCenter(int numOfPoint, int numOfCenter) { for (int i = 0; i < numOfCenter; i++){ centerListNew[i].x1 = 0; centerListNew[i].x2 = 0; centerListNew[i].x3 = 0; centerListNew[i].x4 = 0; centerListNew[i].flag = 0; } for (int i = 0; i < numOfPoint; i++){ centerListNew[pList[i].flag].x1 += pList[i].x1; centerListNew[pList[i].flag].x2 += pList[i].x2; centerListNew[pList[i].flag].x3 += pList[i].x3; centerListNew[pList[i].flag].x4 += pList[i].x4; centerListNew[pList[i].flag].flag++; } for (int i = 0; i < numOfCenter; i++){ centerListNew[i].x1 = centerListNew[i].x1 / centerListNew[i].flag; centerListNew[i].x2 = centerListNew[i].x2 / centerListNew[i].flag; centerListNew[i].x3 = centerListNew[i].x3 / centerListNew[i].flag; centerListNew[i].x4 = centerListNew[i].x4 / centerListNew[i].flag; centerListNew[i].flag = 0; } } void pointClass::resetCenterOld(int numOfCenter) { for (int i = 0; i < numOfCenter; i++){ centerListOld[i] = centerListNew[i]; } } bool pointClass::IsEnd(int numOfCenter) { bool b = false; for (int i = 0; i < numOfCenter; i++){ if (GetDistance(centerListNew[i],centerListOld[i]) > 1){ b = true; break; } } return b; } void pointClass::ExportData(int numOfPoint, int numOfCenter) { ofstream ofile("D:\\kMeansResult.txt"); cout << "本次误差是:" << errorNew << endl; ofile << "本次误差是:" << errorNew << endl; for (int j = 0; j < numOfCenter; j++){ ofile << "第" << j+1 <<"类:" << endl; for (int i = 0; i < numOfPoint; i++){ if (pList[i].flag == j){ ofile << pList[i].x1 << " " << pList[i].x2 << " " << pList[i].x3 << " " <<pList[i].x4 << endl; } } } } pointClass::~pointClass() { delete[] pList; delete[] centerListOld; delete[] centerListNew; }
主函数 main.cpp
#include <iostream> #include "kMeans.h" using namespace std; int main() { kMeans km; km.k_means(); return 0; }