darknet分类源码解析
1 void validate_classifier_multi(char *datacfg, char *filename, char *weightfile) 2 { 3 int i, j; 4 network net = parse_network_cfg(filename); 5 set_batch_network(&net, 1); 6 if(weightfile){ 7 load_weights(&net, weightfile); 8 } 9 srand(time(0)); 10 11 list *options = read_data_cfg(datacfg);//读.data文件到option列表中 12 13 char *label_list = option_find_str(options, "labels", "data/labels.list"); 14 //从读到的.data生成的option列表去找对饮的字段如labels,将labels的配置路径放到label_list指针中, 15 //然后如果labels的配置路径是"data/labels.list",打印“使用默认配置”字样 16 char *valid_list = option_find_str(options, "valid", "data/train.list");// l,key,def; return def 17 int classes = option_find_int(options, "classes", 2); 18 int topk = option_find_int(options, "top", 1); 19 if (topk > classes) topk = classes;//找的比类别还多 20 21 char **labels = get_labels(label_list); 22 //将labels.list标签名读到lables字符指针,可以通过labels[i]访问标签 23 list *plist = get_paths(valid_list);//得到验证集的数据路径 24 int scales[] = {224, 288, 320, 352, 384}; 25 int nscales = sizeof(scales)/sizeof(scales[0]); 26 27 char **paths = (char **)list_to_array(plist); 28 int m = plist->size; 29 free_list(plist); 30 31 float avg_acc = 0; 32 float avg_topk = 0; 33 int* indexes = (int*)calloc(topk, sizeof(int)); 34 35 for(i = 0; i < m; ++i){ 36 int class_id = -1;//一般用负数初始化 37 char *path = paths[i];//这里的路径名包括文件名之外的路径吗? 38 for(j = 0; j < classes; ++j){ 39 if(strstr(path, labels[j])){ 40 //在path字符串中查找labels[j]字符串第一次出现的位置 41 class_id = j; 42 //这里实现了数据集在训练过程中的类别的确定。还是看匹配,只要标签在文件名中 43 break; 44 } 45 } 46 float* pred = (float*)calloc(classes, sizeof(float)); 47 image im = load_image_color(paths[i], 0, 0); 48 for(j = 0; j < nscales; ++j){ 49 image r = resize_min(im, scales[j]); 50 resize_network(&net, r.w, r.h); 51 float *p = network_predict(net, r.data); 52 if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy, 1); 53 axpy_cpu(classes, 1, p, 1, pred, 1); 54 flip_image(r); 55 p = network_predict(net, r.data); 56 axpy_cpu(classes, 1, p, 1, pred, 1); 57 if(r.data != im.data) free_image(r); 58 } 59 free_image(im); 60 top_k(pred, classes, topk, indexes); 61 free(pred); 62 if(indexes[0] == class_id) avg_acc += 1; 63 for(j = 0; j < topk; ++j){ 64 if(indexes[j] == class_id) avg_topk += 1; 65 } 66 67 printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1)); 68 } 69 } 70 71 72 void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top) 73 {//反初始化主要是类对象的析构 74 network net = parse_network_cfg_custom(cfgfile, 1, 0); 75 if(weightfile){ 76 load_weights(&net, weightfile); 77 } 78 set_batch_network(&net, 1); 79 srand(2222222); 80 81 fuse_conv_batchnorm(net); 82 calculate_binary_weights(net); 83 84 list *options = read_data_cfg(datacfg); 85 86 char *name_list = option_find_str(options, "names", 0); 87 if(!name_list) name_list = option_find_str(options, "labels", "data/labels.list"); 88 int classes = option_find_int(options, "classes", 2); 89 if (top == 0) top = option_find_int(options, "top", 1); 90 if (top > classes) top = classes; 91 92 int i = 0; 93 char **names = get_labels(name_list); 94 clock_t time; 95 int* indexes = (int*)calloc(top, sizeof(int)); 96 char buff[256]; 97 char *input = buff; 98 //int size = net.w; 99 while(1){ 100 if(filename){ 101 strncpy(input, filename, 256);//将filename的前256个字符复制到input中。 102 }else{ 103 printf("Enter Image Path: "); 104 fflush(stdout); 105 input = fgets(input, 256, stdin); 106 if(!input) return; 107 strtok(input, "\n"); 108 } 109 image im = load_image_color(input, 0, 0); 110 image r = letterbox_image(im, net.w, net.h); 111 //image r = resize_min(im, size); 112 //resize_network(&net, r.w, r.h); 113 printf("%d %d\n", r.w, r.h); 114 115 float *X = r.data; 116 time=clock(); 117 float *predictions = network_predict(net, X); 118 if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy, 0); 119 top_k(predictions, net.outputs, top, indexes); 120 //按得分来排top k,indexes是新的排序指针,按升序排列,prediction越大的在indexes里面的id越是靠后。 121 printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); 122 for(i = 0; i < top; ++i){ 123 int index = indexes[i]; 124 //hierarchy是一个树形结构体指针变量。应该是没有的。 125 if(net.hierarchy) printf("%d, %s: %f, parent: %s \n",index, names[index], predictions[index], (net.hierarchy->parent[index] >= 0) ? names[net.hierarchy->parent[index]] : "Root"); 126 else printf("%s: %f\n",names[index], predictions[index]); 127 //names[index]是分类的对应的类别名称如yb,ye,yf 128 //predictions[index]是推理置信度 129 } 130 if(r.data != im.data) free_image(r); 131 free_image(im); 132 if (filename) break;//可以批量测试,如果filename是False,跳出 133 } 134 } 135