手写识别——KNN

#include <iostream>
#include<map>  
#include<vector>  
#include<stdio.h>  
#include<cmath>  
#include<cstdlib>  
#include<algorithm>  
#include<fstream>
 
using namespace std;

// Element types of the data set: one label character per sample and
// double-precision feature values.
typedef char   tLabel;
typedef double tData;
// (sample index, distance) pair used when ranking neighbours.
typedef std::pair<int, double> PAIR;

// Data-set dimensions: only these two values need to change when a new
// data set is imported.
const int rowLen = 6;   // number of training samples (rows)
const int colLen = 2;   // number of features per sample (columns)

// File-scope stream objects.
// NOTE(review): check whether these are still needed by the rest of the file.
ifstream fin;
ofstream fout;
 
/// K-nearest-neighbour classifier over a fixed-size training set of
/// rowLen samples, each with colLen features and one label.
///
/// The constructor loads the training samples and the test sample;
/// get_all_distance() then computes distances and get_max_freq_label()
/// votes among the k nearest samples.
class KNN
{
private:
        tData dataSet[rowLen][colLen];    // training samples, one row per sample
        tLabel labels[rowLen];            // label of each training row (parallel to dataSet)
        tData testData[colLen];           // the sample to classify
        int k;                            // number of neighbours that vote
        map<int,double> map_index_dis;    // training-row index -> distance to testData
        map<tLabel,int> map_label_freq;   // label -> vote count among the k nearest rows
        double get_distance(tData *d1,tData *d2);    // Euclidean distance between two samples
public:
        // Loads the training set and reads the test sample; k is the neighbour count.
        // explicit: an int should not silently convert into a whole classifier.
        explicit KNN(int k);

        // Fills map_index_dis with the distance from testData to every training row.
        void get_all_distance();

        // Prints the k nearest rows and the most frequent label among them.
        void get_max_freq_label();

        /// Orders (index, distance) pairs by ascending distance.
        struct CmpByValue
        {
            // const: standard algorithms may invoke the comparator through a const object.
            bool operator() (const PAIR& lhs,const PAIR& rhs) const
            {
                return lhs.second < rhs.second;
            }
        };
};
 
/// Builds the classifier.
///
/// Reads rowLen training samples from "movie_data.txt" (colLen feature
/// values followed by one label per sample). It then reads the colLen
/// feature values of the test sample from standard input.
///
/// The process exits with status 1 if the file cannot be opened or any value
/// fails to parse, so no member is ever left uninitialised.
///
/// @param k number of nearest neighbours that vote in get_max_freq_label().
KNN::KNN(int k)
{
    this->k = k;
    // Only the file name needs to change when importing a new data set.
    const char *fileName = "movie_data.txt";
    // Local stream: closed automatically when the constructor returns.
    ifstream in(fileName);
    if(!in)
    {
        cerr<<"can not open the file "<<fileName<<endl;
        exit(1);
    }
    for(int i = 0; i < rowLen; i++)
    {
        for(int j = 0; j < colLen; j++)
            in>>dataSet[i][j];
        in>>labels[i];
    }
    // Any short or malformed file leaves the stream in a failed state.
    if(!in)
    {
        cerr<<"failed to read "<<rowLen<<" samples from "<<fileName<<endl;
        exit(1);
    }

    cout<<"please input the test data :"<<endl;
    // Read the test sample's features.
    for(int i = 0; i < colLen; i++)
        cin>>testData[i];
    if(!cin)
    {
        cerr<<"invalid test data"<<endl;
        exit(1);
    }
}
 
/// Returns the Euclidean distance between two samples of colLen features each.
///
/// @param d1 first sample (at least colLen values)
/// @param d2 second sample (at least colLen values)
double KNN:: get_distance(tData *d1,tData *d2)
{
    double sum = 0;
    for(int i = 0; i < colLen; i++)
    {
        const double diff = d1[i]-d2[i];
        // Plain multiply: pow() is a general transcendental and needlessly slow for squaring.
        sum += diff*diff;
    }
    return sqrt(sum);
}
 
//计算测试样本与训练集中每个样本的距离   
void KNN:: get_all_distance()  
{  
    double distance;  
    int i;  
    for(i=0;i<rowLen;i++)  
    {  
        distance = get_distance(dataSet[i],testData);  
        //<key,value> => <i,distance>  
        map_index_dis[i] = distance;  
    }  
    //遍历map,打印各个序号和距离
    map<int,double>::const_iterator it = map_index_dis.begin();  
    while(it!=map_index_dis.end())  
    {  
        cout<<"index = "<<it->first<<" distance = "<<it->second<<endl;  
        it++;  
    }  
}  
   
//在k值设定的情况下,计算测试数据属于哪个lable,并输出   
void KNN:: get_max_freq_label()  
{  
    //将map_index_dis转换为vec_index_dis  
    vector<PAIR> vec_index_dis( map_index_dis.begin(),map_index_dis.end() );   
    //对vec_index_dis进行从低到高排序,以获得最近距离数据
    
    sort(vec_index_dis.begin(),vec_index_dis.end(),CmpByValue());  
  
    for(int i=0;i<k;i++)  
    {  
        cout<<"the index = "<<vec_index_dis[i].first<<" the distance = "<<vec_index_dis[i].second<<" the label = "<<labels[vec_index_dis[i].first]<<" the coordinate ( "<<dataSet[ vec_index_dis[i].first ][0]<<","<<dataSet[ vec_index_dis[i].first ][1]<<" )"<<endl;  
        //calculate the count of each label  
        map_label_freq[ labels[ vec_index_dis[i].first ]  ]++;  
    }  
  
    map<tLabel,int>::const_iterator map_it = map_label_freq.begin();  
    tLabel label;  
    int max_freq = 0;  
    //find the most frequent label  
    while( map_it != map_label_freq.end() )  
    {  
        if( map_it->second > max_freq )  
        {  
            max_freq = map_it->second;  
            label = map_it->first;  
        }  
        map_it++;  
    }  
    cout<<"The test data belongs to the "<<label<<" label"<<endl;  
}  
  
int main()  
{  
    int k ;  
    cout<<"please input the k value : "<<endl;  
    cin>>k;  
    KNN knn(k);  
    knn.get_all_distance();  
    knn.get_max_freq_label();    
    return 0;  
}  
 
 

 

posted @ 2021-05-02 22:58  兜转转  阅读(45)  评论(0)  编辑  收藏  举报