K-means (PRML) in C++

原始数据

#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <string>
#include <algorithm>
#include <numeric>
#include <cmath>
#include <limits>

// Reads whitespace-separated (x, y) pairs from a text file and appends each
// pair to vv_data as a two-element row.
//
// filename: path of the file to read.
// vv_data:  destination; rows already present are kept, new rows are appended.
//
// Reading stops at the first pair that cannot be extracted completely (end of
// file, a dangling single value, or non-numeric text). If the file cannot be
// opened, a message is written to std::cerr and vv_data is left unchanged.
template <class T>
void ReadDataFromFile(const std::string &filename, std::vector<std::vector<T> > &vv_data) {
    std::ifstream vm_info(filename.c_str());
    if(!vm_info) {
        std::cerr<<"ReadDataFromFile: cannot open "<<filename<<"\n";
        return;
    }

    T x=T(), y=T();
    // Loop on the extraction result instead of eof(): eof() only becomes true
    // after a read has already failed, which would append one bogus row after
    // the last record (and never terminates on a stream that failed to open).
    while(vm_info >> x >> y) {
        std::vector<T> v_data;
        v_data.push_back(x);
        v_data.push_back(y);
        vv_data.push_back(v_data);
    }
    // The stream is closed by the std::ifstream destructor.
}

// Writes vv to std::cout, one row per line with a single space after every
// element, followed by a summary line reporting the number of rows.
template <class T>
void Display2DVector(std::vector<std::vector<T> > &vv) {
    typedef typename std::vector<std::vector<T> >::const_iterator RowIter;

    for(RowIter row=vv.begin(); row!=vv.end(); ++row) {
        const std::vector<T> &cells=*row;
        for(std::size_t col=0; col<cells.size(); ++col) {
            std::cout<<cells[col]<<" ";
        }
        std::cout<<"\n";
    }
    // std::endl also flushes the stream after the summary line.
    std::cout<<"--------the total of the 2DVector is "<<vv.size()<<std::endl;
}

// Appends k zero-valued indicator slots to every row of vv.
//
// vv: rows to extend in place.
// k:  number of slots to append per row; values <= 0 leave vv unchanged.
template <class T>
void AddIndicator(std::vector<std::vector<T> > &vv, const int &k) {
    // Guard first: comparing an unsigned counter against a negative k would
    // convert k to a huge unsigned value and grow the rows without bound.
    if(k <= 0) {
        return;
    }
    const std::size_t extra=static_cast<std::size_t>(k);
    for(std::size_t i=0; i<vv.size(); ++i) {
        std::vector<T> &row=vv[i];
        row.resize(row.size()+extra, T(0));
    }
}

// Assignment step: marks, for every point, the centroid with the smallest
// squared Euclidean distance.
//
// vv: one row per point, laid out as [x, y, r_0, ..., r_{k-1}]. After the
//     call exactly one r_j per row is 1 and the others are 0.
// u:  centroids flattened as [x_0, y_0, x_1, y_1, ...]; must hold at least
//     2*k values.
// k:  number of clusters; values <= 0 leave vv unchanged.
//
// Throws std::out_of_range (from vector::at) if a row or u is too short.
// Ties go to the lowest cluster index; if no distance compares below the
// initial bound (e.g. every distance is NaN), cluster 0 is chosen.
template <class T1, class T2>
void UpdateIndicator(std::vector<std::vector<T1> > &vv, const std::vector<T2> &u, const int &k) {
    if(k <= 0) {
        return;
    }
    const std::size_t clusters=static_cast<std::size_t>(k);

    for(std::size_t i=0; i<vv.size(); ++i) {
        std::vector<T1> &row=vv.at(i);
        const double px=static_cast<double>(row.at(0));
        const double py=static_cast<double>(row.at(1));

        double best=std::numeric_limits<double>::max();
        std::size_t nearest=0;
        for(std::size_t j=0; j<clusters; ++j) {
            const double dx=px-static_cast<double>(u.at(j*2));
            const double dy=py-static_cast<double>(u.at(j*2+1));
            const double dist=dx*dx+dy*dy;
            if(dist < best) {
                best=dist;
                nearest=j;
            }
            // Clear the assignment left over from the previous iteration;
            // otherwise a point that changes cluster stays flagged in both.
            row.at(j+2)=T1(0);
        }
        row.at(nearest+2)=T1(1);
    }
}

// Update step: replaces each centroid with the indicator-weighted mean of the
// points.
//
// vv: one row per point, laid out as [x, y, r_0, ..., r_{k-1}].
// u:  centroids flattened as [x_0, y_0, x_1, y_1, ...]; updated in place.
// k:  number of clusters; values <= 0 leave u unchanged.
//
// A cluster whose indicator total is not positive (no point assigned to it)
// keeps its previous centroid; dividing by that zero total would produce NaN
// coordinates that never recover in later iterations.
// Throws std::out_of_range (from vector::at) if a row or u is too short.
template <class T1, class T2>
void UpdateMeans(const std::vector<std::vector<T1> > &vv, std::vector<T2> &u, const int &k) {
    if(k <= 0) {
        return;
    }
    const std::size_t clusters=static_cast<std::size_t>(k);
    std::vector<T2> updated(u);  // start from the old centroids

    for(std::size_t c=0; c<clusters; ++c) {
        double weight=0.0;
        double sum_x=0.0;
        double sum_y=0.0;
        for(std::size_t j=0; j<vv.size(); ++j) {
            const std::vector<T1> &row=vv.at(j);
            const double r=static_cast<double>(row.at(c+2));
            weight+=r;
            sum_x+=r*static_cast<double>(row.at(0));
            sum_y+=r*static_cast<double>(row.at(1));
        }
        if(weight > 0.0) {
            updated.at(c*2)=static_cast<T2>(sum_x/weight);
            updated.at(c*2+1)=static_cast<T2>(sum_y/weight);
        }
    }
    u.swap(updated);
}

// Returns the distortion of the current assignment: the sum, over every point
// and every cluster, of the point's indicator for that cluster multiplied by
// the squared Euclidean distance between the point and that cluster's centroid.
//
// vv: one row per point, laid out as [x, y, r_0, ..., r_{k-1}].
// u:  centroids flattened as [x_0, y_0, x_1, y_1, ...].
// k:  number of clusters.
template <class T1, class T2>
double DistortionMeasure(const std::vector<std::vector<T1> > &vv, const std::vector<T2> &u, const int &k) {
    typedef typename std::vector<std::vector<T1> >::const_iterator RowIter;

    // Same conversion the signed/unsigned loop comparison performs implicitly.
    const std::size_t clusters=static_cast<std::size_t>(k);
    double total=0.0;

    for(RowIter row=vv.begin(); row!=vv.end(); ++row) {
        for(std::size_t c=0; c<clusters; ++c) {
            const double dx_sq=std::pow(row->at(0)-u.at(c*2), 2.0);
            const double dy_sq=std::pow(row->at(1)-u.at(c*2+1), 2.0);
            total+=row->at(c+2)*(dx_sq+dy_sq);
        }
    }

    return total;
}

int main() {
    int k=4;
    double mean[]={39, 42, 70, 2, 230, 10, 190, 85};
    std::vector<double> u(mean, mean+sizeof(mean)/sizeof(mean[0]));

    std::string oridata="kmeans.dat";
    std::vector<std::vector<double> > vv_data;

    ReadDataFromFile(oridata, vv_data);

    AddIndicator(vv_data, k);

    std::cout<<"the original mean: \n";
    for(std::vector<double>::const_iterator it=u.begin(); it!=u.end(); ++it) {
        std::cout<<*it<<" ";
    }
    std::cout<<std::endl;

    double cost_old=std::numeric_limits<double>::max();
    while(true) {
        double cost_new=DistortionMeasure(vv_data, u, k);

        if(std::abs(cost_new-cost_old)<0.0000001)
            break;

        UpdateIndicator(vv_data, u, k);

        UpdateMeans(vv_data, u, k);
        cost_old=cost_new;
    }

    std::cout<<"the new mean: \n";
    for(std::vector<double>::const_iterator it=u.begin(); it!=u.end(); ++it) {
        std::cout<<*it<<" ";
    }
    std::cout<<std::endl;

    return 0;
}

The two phases of re-assigning data points to clusters and re-computing the cluster means are repeated in turn until there is no further change in the assignments (or until some maximum number of iterations is exceeded).

posted @   东宫得臣  阅读(149)  评论(0编辑  收藏  举报
编辑推荐:
· AI与.NET技术实操系列:基于图像分类模型对图像进行分类
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 25岁的心里话
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 按钮权限的设计及实现
点击右上角即可分享
微信分享提示