k-means 算法的 C++ 实现(改写为面向对象版本)

画的类图如下:

程序分为 kMeans 和 pointClass 两个类。kMeans 有两个成员变量:数据点总个数和最终聚类个数。pointClass 使用结构体 point(4 个属性值 x1~x4 加上所属类别标记 flag)来存储数据点和聚类中心。

具体代码如下:

kMeans.h

#ifndef KMEANS_H
#define KMEANS_H

/// Driver for one k-means clustering run.
///
/// The constructor reads the two sizes from standard input; k_means() builds
/// a pointClass with them, iterates the assignment/update steps and exports
/// the result through pointClass::ExportData.
class kMeans{
protected:
    int numOfPoint;   // total number of data points to read from the data file
    int numOfCenter;  // number of clusters (k) to produce
public:
    /// Prompts for and reads numOfPoint and numOfCenter from std::cin.
    kMeans();
    /// Runs the clustering loop and exports the result.
    void k_means();
};

#endif // KMEANS_H
View Code

kMeans.cpp

#include "kMeans.h"
#include "point.h"
#include <iostream>
using namespace std;

// Prompts for the number of data points and the number of clusters and reads
// both from std::cin. Both members start at 0. If the read fails, they are
// reset to 0, so the object never holds indeterminate values. (A failed
// extraction can leave the second operand unwritten.)
kMeans::kMeans()
    : numOfPoint(0), numOfCenter(0)
{
    // Prompt (Chinese): "enter the number of data points and of final clusters".
    std::cout << "分别输入数据点和最终聚类的个数: ";
    if (!(std::cin >> numOfPoint >> numOfCenter)){
        std::cin.clear();
        numOfPoint = 0;
        numOfCenter = 0;
        std::cerr << "输入无效!!" << '\n';
    }
}

// Runs k-means with the sizes read by the constructor.
//
// Steps:
// - Loads the data and picks random initial centers.
// - Repeats the assignment step (setPoint) and the update step (getNewCenter)
//   until IsEnd reports that no center moved noticeably, or until getError
//   signals that the error stopped decreasing.
// - On normal convergence, exports the clusters.
void kMeans::k_means()
{
    // Reject sizes the algorithm cannot work with. Without this check:
    // - k > n makes pointClass::InitCenter retry forever, because it cannot
    //   pick numOfCenter distinct points.
    // - k <= 0 makes setPoint read centerListOld[0] of an empty array.
    if (numOfPoint <= 0 || numOfCenter <= 0 || numOfCenter > numOfPoint){
        std::cerr << "数据点个数和聚类个数必须为正数, 且聚类个数不能大于数据点个数!!" << std::endl;
        return;
    }

    pointClass pc(numOfPoint, numOfCenter);
    pc.InitCenter(numOfPoint, numOfCenter);

    // IsEnd returns true while at least one center still moves by more than 1,
    // so `moving` stays true until the centers have settled.
    bool moving = true;
    while (moving){
        pc.setPoint(numOfPoint, numOfCenter);
        if (pc.getError()){
            // The error did not decrease. Leave with moving == true, which is
            // reported as a failure below.
            break;
        }
        pc.getNewCenter(numOfPoint, numOfCenter);
        moving = pc.IsEnd(numOfCenter);
        pc.resetCenterOld(numOfCenter);
    }
    if (moving) {
        std::cout << "聚类操作无法完成!!" << std::endl;
    }else{
        pc.ExportData(numOfPoint, numOfCenter);
    }
}
View Code

point.h

#ifndef POINT_H
#define POINT_H

/// One 4-dimensional sample or cluster center.
/// For data points, `flag` holds the index of the assigned cluster.
/// For center entries, getNewCenter uses `flag` as a scratch member counter
/// and resets it to 0.
struct point{
    double x1, x2, x3, x4;
    int flag;
};

/// Owns the data points and the old/new cluster centers, and implements the
/// individual k-means steps that kMeans::k_means() drives.
class pointClass{
protected:
    double errorOld, errorNew;                    // previous / current clustering error
    point *pList, *centerListNew, *centerListOld; // owned arrays, released in the destructor
    double GetDistance(point point1, point point2);
    bool GetExist(int xm, int centerList[], int n);
public:
    pointClass (int numOfPoint, int numOfCenter);
    void InitCenter (int numOfPoint, int numOfCenter);
    void setPoint (int numOfPoint, int numOfCenter);
    bool getError();
    void getNewCenter (int numOfPoint, int numOfCenter);
    bool IsEnd (int numOfCenter);
    void ExportData (int numOfPoint, int numOfCenter);
    void resetCenterOld (int numOfCenter);
    ~pointClass ();
private:
    // A copy would duplicate the three owning raw pointers and delete[] them
    // twice. These are declared but intentionally never defined, which forbids
    // copying (pre-C++11 idiom).
    pointClass(const pointClass &);
    pointClass &operator=(const pointClass &);
};

#endif // POINT_H
View Code

point.cpp

#include <cmath>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <vector>
#include "point.h"
using namespace std;

// Euclidean distance between two 4-dimensional points. The `flag` field is
// ignored. Squares are written as products and the root uses std::sqrt rather
// than std::pow, because pow is a general transcendental routine and much
// slower for these fixed exponents.
double pointClass::GetDistance(point point1, point point2)
{
    const double d1 = point1.x1 - point2.x1;
    const double d2 = point1.x2 - point2.x2;
    const double d3 = point1.x3 - point2.x3;
    const double d4 = point1.x4 - point2.x4;
    return std::sqrt(d1 * d1 + d2 * d2 + d3 * d3 + d4 * d4);
}

// Linear search. Reports whether `xm` occurs among the first `n` entries of
// `CentIndex`. InitCenter uses it to avoid picking the same point twice.
bool pointClass::GetExist(int xm, int CentIndex[], int n)
{
    int k = 0;
    while (k < n){
        if (CentIndex[k] == xm){
            return true;
        }
        ++k;
    }
    return false;
}

// Allocates the data and center arrays with every coordinate and flag set to 0.
// Then loads `numOfPoint` samples from D:\IrisData.txt, each sample being four
// whitespace-separated numbers, and resets both error trackers to 0.
// The process exits with a failure status in two cases: the file cannot be
// opened, or it does not contain `numOfPoint` readable samples.
pointClass::pointClass(int numOfPoint, int numOfCenter)
{
    // The trailing () value-initializes each element, so all fields start at
    // 0. This replaces the explicit zeroing loop.
    pList = new point[numOfPoint]();
    centerListOld = new point[numOfCenter]();
    centerListNew = new point[numOfCenter]();
    errorNew = 0;
    errorOld = 0;

    std::ifstream ifile("D:\\IrisData.txt");
    if (!ifile.is_open()){
        std::cerr << "file" << std::endl;
        std::exit(EXIT_FAILURE);  // previously exit(0), which reported success on failure
    }
    for (int i = 0; i < numOfPoint; i++){
        ifile >> pList[i].x1 >> pList[i].x2 >> pList[i].x3 >> pList[i].x4;
        if (!ifile){
            // A short or malformed file used to leave the remaining points
            // silently unread.
            std::cerr << "file: only " << i << " of " << numOfPoint
                      << " points could be read" << std::endl;
            std::exit(EXIT_FAILURE);
        }
    }
    // ifile is closed by its destructor.
}

// Chooses `numOfCenter` distinct data points at random, seeded from the
// clock, and copies them into centerListOld as the initial cluster centers.
// When that is impossible (no points, or more centers than points) it exits
// with a failure status. Previously those cases caused an endless retry loop
// or a modulo by zero.
void pointClass::InitCenter(int numOfPoint, int numOfCenter)
{
    if (numOfCenter <= 0){
        return;  // nothing to choose
    }
    if (numOfPoint <= 0 || numOfCenter > numOfPoint){
        std::cerr << "InitCenter: cannot choose " << numOfCenter
                  << " distinct centers from " << numOfPoint << " points" << std::endl;
        std::exit(EXIT_FAILURE);
    }

    // Indices chosen so far. A vector frees itself; the original new[]
    // buffer was never deleted.
    std::vector<int> centerIndex(numOfCenter);
    std::srand(static_cast<unsigned>(std::time(0)));
    for (int i = 0; i < numOfCenter; i++){
        int xm;
        do {
            xm = std::rand() % numOfPoint;
        } while (GetExist(xm, &centerIndex[0], i));  // retry until the index is new
        centerIndex[i] = xm;
        centerListOld[i] = pList[xm];
    }
}


void pointClass::setPoint(int numOfPoint, int numOfCenter)
{
    errorNew = 0;
    for (int i = 0; i < numOfPoint; i++){
        int flagi = 0;
        double distance = GetDistance(pList[i], centerListOld[0]);
        for (int j = 1; j < numOfCenter; j++){
            double tmp = GetDistance(pList[i], centerListOld[j]);
            if (tmp < distance){
                tmp = distance;
                flagi = j;
            }
        }
        pList[i].flag = flagi;
        errorNew = GetDistance(pList[i], centerListOld[flagi]);
    }
}

// Returns true when a previous error has been recorded (errorOld != 0) and the
// current error errorNew is not smaller than it, i.e. the iteration stopped
// improving.
// NOTE(review): in this file errorOld is only ever assigned 0 (in the
// constructor), so as written this always returns false. Confirm the intent.
bool pointClass::getError()
{
    if (errorOld == 0){
        return false;
    }
    return errorNew >= errorOld;
}

void pointClass::getNewCenter(int numOfPoint, int numOfCenter)
{
    for (int i = 0; i < numOfCenter; i++){
        centerListNew[i].x1 = 0;
        centerListNew[i].x2 = 0;
        centerListNew[i].x3 = 0;
        centerListNew[i].x4 = 0;
        centerListNew[i].flag = 0;
    }
    for (int i = 0; i < numOfPoint; i++){
        centerListNew[pList[i].flag].x1 += pList[i].x1;
        centerListNew[pList[i].flag].x2 += pList[i].x2;
        centerListNew[pList[i].flag].x3 += pList[i].x3;
        centerListNew[pList[i].flag].x4 += pList[i].x4;
        centerListNew[pList[i].flag].flag++;
    }
    for (int i = 0; i < numOfCenter; i++){
        centerListNew[i].x1 = centerListNew[i].x1 / centerListNew[i].flag;
        centerListNew[i].x2 = centerListNew[i].x2 / centerListNew[i].flag;
        centerListNew[i].x3 = centerListNew[i].x3 / centerListNew[i].flag;
        centerListNew[i].x4 = centerListNew[i].x4 / centerListNew[i].flag;
        centerListNew[i].flag = 0;
    }
}

void pointClass::resetCenterOld(int numOfCenter)
{
    for (int i = 0; i < numOfCenter; i++){
        centerListOld[i] = centerListNew[i];
    }
}

// Convergence test. Despite its name, it returns true while clustering has
// NOT settled: at least one center moved farther than 1 between centerListOld
// and centerListNew. It returns false once every center moved by at most 1.
bool pointClass::IsEnd(int numOfCenter)
{
    int i = 0;
    while (i < numOfCenter && !(GetDistance(centerListNew[i], centerListOld[i]) > 1)){
        ++i;
    }
    // Stopping early means some center moved more than 1.
    return i < numOfCenter;
}

// Prints the final error to stdout. Then writes the error to
// D:\kMeansResult.txt, followed by each cluster's header line and its member
// points, one "x1 x2 x3 x4" line per point.
// If the result file cannot be opened, this is reported on stderr instead of
// silently writing nothing.
void pointClass::ExportData(int numOfPoint, int numOfCenter)
{
    std::cout << "本次误差是:" << errorNew << std::endl;

    std::ofstream ofile("D:\\kMeansResult.txt");
    if (!ofile.is_open()){
        std::cerr << "cannot open D:\\kMeansResult.txt" << std::endl;
        return;
    }
    // '\n' instead of std::endl: same bytes, without a flush per line. The
    // stream is flushed and closed by its destructor.
    ofile << "本次误差是:" << errorNew << '\n';
    for (int j = 0; j < numOfCenter; j++){
        ofile << j + 1 << "类:" << '\n';
        for (int i = 0; i < numOfPoint; i++){
            if (pList[i].flag == j){
                ofile << pList[i].x1 << " " << pList[i].x2 << " " << pList[i].x3 << " " << pList[i].x4 << '\n';
            }
        }
    }
}


// Releases the three arrays allocated with new[] in the constructor.
pointClass::~pointClass()
{
    delete[] centerListNew;
    delete[] centerListOld;
    delete[] pList;
}
View Code

主函数main.cpp

#include <iostream>
#include "kMeans.h"
using namespace std;

int main()
{
    kMeans km;
    km.k_means();
    return 0;
}
View Code

 

posted on 2013-10-03 22:30  张三的哥哥  阅读(984)  评论(0编辑  收藏  举报