《机器学习实战》读书笔记—k近邻算法c语言实现(win下)
#include <stdio.h> #include <io.h> #include <math.h> #include <stdlib.h> #define K 10 //kNN中选取最近邻居的个数 #define LINE 1024 //每个文件字符数 const char *to_search_train = "F:\\kNN\\train\\*.txt"; //train数据地址 const char *to_search_test = "F:\\kNN\\test\\*.txt"; //test数据地址 struct //定义结构体 储存train数据和标签 { int train_mat[2000][LINE]; //矩阵每一行都是1*LINE的矩阵 int train_label[2000]; //储存每个数据的标签 }Train; struct //定义结构体 储存test数据和标签 { int test_mat[2000][LINE]; int test_label[2000]; }Test; float mat_dist[1000][2000]; //定义距离矩阵,储存每个test数据到每个train数据的距离 //定义子函数,功能:将每个test数据与所有train数据的距离进行排序,选取距离最小的前K个,这K个数据标签类型最多的,将此标签返回给主函数 int BubbleSort(float mat_dist_row[], int label[],int num_train) { int i,j,k; int temp,temp_label; int num[K] = {0}; int max = 0; int label_final = 0; for(i = 0;i < num_train;i++) //冒泡排序,距离从小到大,同时将label对应跟随 { for(j = i+1;j < num_train;j++) { if(mat_dist_row[i] > mat_dist_row[j]) { temp = mat_dist_row[i]; mat_dist_row[i] = mat_dist_row[j]; mat_dist_row[j] = temp; temp_label = label[i]; label[i] = label[j]; label[j] = temp_label; } } } for (k = 0;k < K;k++) //统计前K个数据,各种标签的个数 { switch(label[k]) { case 0: num[0]++;break; case 1: num[1]++;break; case 2: num[2]++;break; case 3: num[3]++;break; case 4: num[4]++;break; case 5: num[5]++;break; case 6: num[6]++;break; case 7: num[7]++;break; case 8: num[8]++;break; case 9: num[9]++;break; default: break; } } max = num[0]; for(i = 0;i < K;i++) //标签类型最多的,选择次标签 { if (num[i] > max ) { max = num[i]; label_final = i; } } return label_final; } int main() { FILE *fp; int c; //用于逐个读入字符数据 int train_i = 0,train_j = 0,test_i = 0,test_j = 0; //用于循环 int count_train = 0,count_test = 0; //用于统计train和test文件的个数 int i,j,k,l; //用于循环 int sum = 0; //距离求和 int update_label[2000]; //每次调用函数,更新label int classifier; //记录返回的标签类型 int count = 0; //错误的个数 float rate; //错误率 char str_adr[255]; //fopen函数读入文件名时地址 long handle; //用于查找的句柄 struct _finddata_t fileinfo; //文件信息的结构体 handle = _findfirst(to_search_train,&fileinfo); //第一次查找 sprintf(str_adr, "F:\\kNN\\train\\%s", 
fileinfo.name); //文件名赋给str_adr if(-1 == handle) { printf("File not exit\n"); } else { switch(fileinfo.name[0]) //给第一个文件赋予标签 { case '0': Train.train_label[count_train] = 0;break; case '1': Train.train_label[count_train] = 1;break; case '2': Train.train_label[count_train] = 2;break; case '3': Train.train_label[count_train] = 3;break; case '4': Train.train_label[count_train] = 4;break; case '5': Train.train_label[count_train] = 5;break; case '6': Train.train_label[count_train] = 6;break; case '7': Train.train_label[count_train] = 7;break; case '8': Train.train_label[count_train] = 8;break; case '9': Train.train_label[count_train] = 9;break; default: break; } count_train++; if((fp = fopen(str_adr,"r")) == NULL) printf("Error!Can't open the file!\n"); else //将文件中'0'和'1'字符转化为数字0,1,并储存 { while((c = fgetc(fp)) != EOF) { if(c == '0' || c == '1') { Train.train_mat[train_i][train_j] = c - '0'; train_j++; } } } fclose(fp); while(!_findnext(handle,&fileinfo)) //循环查找其他符合的文件,知道找不到其他的为止 { train_j = 0; train_i++; sprintf(str_adr, "F:\\kNN\\train\\%s", fileinfo.name); switch(fileinfo.name[0]) //给后面文件赋予标签 { case '0': Train.train_label[count_train] = 0;break; case '1': Train.train_label[count_train] = 1;break; case '2': Train.train_label[count_train] = 2;break; case '3': Train.train_label[count_train] = 3;break; case '4': Train.train_label[count_train] = 4;break; case '5': Train.train_label[count_train] = 5;break; case '6': Train.train_label[count_train] = 6;break; case '7': Train.train_label[count_train] = 7;break; case '8': Train.train_label[count_train] = 8;break; case '9': Train.train_label[count_train] = 9;break; default: break; } if((fp = fopen(str_adr,"r")) == NULL) printf("Error!Can't open the file!\n"); else //将文件中'0'和'1'字符转化为数字0,1,并储存 { while((c = fgetc(fp)) != EOF) { if(c == '0' || c == '1') { Train.train_mat[train_i][train_j] = c - '0'; train_j++; } } } count_train++; fclose(fp); } } _findclose(handle); //下面重复上面文件读入和储存的过程,读入并储存所有test数据 handle = 
_findfirst(to_search_test,&fileinfo); sprintf(str_adr, "F:\\kNN\\test\\%s", fileinfo.name); if(-1 == handle) { printf("File not exit\n"); } else { switch(fileinfo.name[0]) { case '0': {Test.test_label[count_test] = 0;break;} case '1': {Test.test_label[count_test] = 1;break;} case '2': {Test.test_label[count_test] = 2;break;} case '3': {Test.test_label[count_test] = 3;break;} case '4': {Test.test_label[count_test] = 4;break;} case '5': {Test.test_label[count_test] = 5;break;} case '6': {Test.test_label[count_test] = 6;break;} case '7': {Test.test_label[count_test] = 7;break;} case '8': {Test.test_label[count_test] = 8;break;} case '9': {Test.test_label[count_test] = 9;break;} default: break; } count_test++; if((fp = fopen(str_adr,"r")) == NULL) printf("Error!Can't open the file!\n"); else { while((c = fgetc(fp)) != EOF) { if(c == '0' || c == '1') { Test.test_mat[test_i][test_j] = c - '0'; test_j++; } } } fclose(fp); while(!_findnext(handle,&fileinfo)) { test_j = 0; test_i++; sprintf(str_adr, "F:\\kNN\\test\\%s", fileinfo.name); switch(fileinfo.name[0]) { case '0': {Test.test_label[count_test] = 0;break;} case '1': {Test.test_label[count_test] = 1;break;} case '2': {Test.test_label[count_test] = 2;break;} case '3': {Test.test_label[count_test] = 3;break;} case '4': {Test.test_label[count_test] = 4;break;} case '5': {Test.test_label[count_test] = 5;break;} case '6': {Test.test_label[count_test] = 6;break;} case '7': {Test.test_label[count_test] = 7;break;} case '8': {Test.test_label[count_test] = 8;break;} case '9': {Test.test_label[count_test] = 9;break;} default: break; } if((fp = fopen(str_adr,"r")) == NULL) printf("Error!Can't open the file!\n"); else { while((c = fgetc(fp)) != EOF) { if(c == '0' || c == '1') { Test.test_mat[test_i][test_j] = c - '0'; test_j++; } } } count_test++; fclose(fp); } } _findclose(handle); for (i = 0;i < count_test;i++) //计算每个test(循环中的i)数据到每个train(循环中的j)数据的距离 { for (j = 0;j < count_train;j++) { for (k = 0;k < LINE;k++) { sum =sum + 
(Test.test_mat[i][k]-Train.train_mat[j][k])*(Test.test_mat[i][k]-Train.train_mat[j][k]); } mat_dist[i][j] = sqrt(sum); sum = 0; } for (l = 0;l < count_train;l++) //更新train数据的label { update_label[l] = Train.train_label[l]; } classifier = BubbleSort(mat_dist[i],update_label,count_train);//调用子函数,得到第i个test数据的标签 if (Test.test_label[i] != classifier) //统计错误个数 { count++; } printf("the real answer is: %d, the classififier is: %d\n",Test.test_label[i],classifier);//打印 } rate = (float)count/count_test; //计算错误率 printf("the total number of errors is: %d\n",count); //打印 printf("the total error rate is: %f\n",rate); return 0; }
干了将近一周才把这个程序写出来,其中遇到了很多很多问题,下面做一点总结:
1、读入文件中的数据不熟悉。在读入txt文件上耗费了太多的时间。
2、对数组、指针了解太少。大数组要定义在函数外边作为全局变量(存放在静态存储区而不是栈上),这样就不会因为“太大”而导致栈溢出;数组作为参数传递时传的是首元素的地址,所以在函数内修改数组,实参也会跟着改变。
3、对新定义的变量,能赋初值的就赋上初值。
4、杜绝依赖编译、运行来检验错误的思想,要先自己检查、确认代码没问题了,再去编译和运行。
几个尚未解决的问题:
1、数组如何定义不会“太大”;
2、程序中读入文件的那部分代码如何变成通用的子函数;
3、指针不会用;
4、找大神帮着改一下提高效率。
这些问题要解决!