在YOLOv3中用K-Means聚类出新数据集的Anchor尺寸

参考博客:

聚类kmeans算法在yolov3中的应用 https://www.cnblogs.com/sdu20112013/p/10937717.html

这篇博客写得非常详细,并给出了GitHub上的代码:https://github.com/AlexeyAB/darknet/blob/master/scripts/gen_anchors.py

整体代码如下:

  1 '''
  2 Created on Feb 20, 2017
  3 @author: jumabek
  4 '''
  5 from os import listdir
  6 from os.path import isfile, join
  7 import argparse
  8 # import cv2
  9 import numpy as np
 10 import sys
 11 import os
 12 import shutil
 13 import random
 14 import math
 15 width_in_cfg_file = 416.
 16 height_in_cfg_file = 416.
 17 
 18 
 19 def IOU(x, centroids):
 20     '''
 21     :param x: 当前gt的w和h
 22     :param centroids: 质心
 23     :return:当前gt与每个质心的相似度,np.array形式
 24     '''
 25     similarities = []
 26     k = len(centroids)
 27     for centroid in centroids:
 28         c_w, c_h = centroid
 29         w, h = x
 30         if c_w >= w and c_h >= h:
 31             similarity = w * h / (c_w * c_h)
 32         elif c_w >= w and c_h <= h:
 33             similarity = w * c_h / (w * h + (c_w - w) * c_h)  # 交叉面积/总面积
 34         elif c_w <= w and c_h >= h:
 35             similarity = c_w * h / (w * h + c_w * (c_h - h))
 36         else:  # means both w,h are bigger than c_w and c_h respectively
 37             similarity = (c_w * c_h) / (w * h)
 38         similarities.append(similarity)  # will become (k,) shape
 39     return np.array(similarities)
 40 
 41 
 42 def avg_IOU(X, centroids):
 43     n, d = X.shape
 44     sum = 0.
 45     for i in range(X.shape[0]):
 46         # note IOU() will return array which contains IoU for each centroid and X[i] // slightly ineffective, but I am too lazy
 47         sum += max(IOU(X[i], centroids))
 48     return sum / n
 49 
 50 
 51 def write_anchors_to_file(centroids, X, anchor_file):
 52     f = open(anchor_file, 'w')
 53     anchors = centroids.copy()
 54     print(anchors.shape)
 55     for i in range(anchors.shape[0]):
 56         anchors[i][0] *= width_in_cfg_file  # / 32. YOLOv3不用除以32 # 归一化后的宽高乘以预设的图片宽高
 57         anchors[i][1] *= height_in_cfg_file  # / 32.
 58     widths = anchors[:, 0]
 59     sorted_indices = np.argsort(widths)
 60     print('Anchors = ', anchors[sorted_indices])
 61     for i in sorted_indices[:-1]:
 62         # 将前n-1个anchor写入txt
 63         f.write('%0.2f,%0.2f, ' % (anchors[i, 0], anchors[i, 1]))
 64     # there should not be comma after last anchor, that's why
 65     # 最后一个anchor写完以后需要换行,所以单独填写
 66     f.write('%0.2f,%0.2f\n' % (anchors[sorted_indices[-1:], 0], anchors[sorted_indices[-1:], 1]))
 67     f.write('%f\n' % (avg_IOU(X, centroids)))
 68     print()
 69 
 70 
 71 def kmeans(X, centroids, eps, anchor_file):
 72     '''
 73 
 74     :param X: annotation_dims,所有的标注信息中的宽和高
 75     :param centroids: 随机生成的质心
 76     :param eps:
 77     :param anchor_file: 保存结果的文件
 78     :return:
 79     '''
 80     N = X.shape[0]
 81     iterations = 0
 82     k, dim = centroids.shape
 83     prev_assignments = np.ones(N) * (-1)
 84     iter = 0
 85     old_D = np.zeros((N, k))
 86     while True:
 87         D = []
 88         iter += 1
 89         for i in range(N):
 90             # 计算gt框与质心之间的距离,相似度越大,说明当前gt越接近于质心,此距离就应该越小
 91             d = 1 - IOU(X[i], centroids)
 92             D.append(d)
 93         D = np.array(D)  # D.shape = (N,k)
 94         print("iter {}: dists = {}".format(iter, np.sum(np.abs(old_D - D))))
 95         # assign samples to centroids
 96         assignments = np.argmin(D, axis=1)  # 返回每一行的最小值的下标.即当前样本应该归为k个质心中的哪一个质心.
 97         if (assignments == prev_assignments).all():  # 质心已经不再变化
 98             print("Centroids = ", centroids)
 99             write_anchors_to_file(centroids, X, anchor_file)
100             return
101         # calculate new centroids,更新质心
102         centroid_sums = np.zeros((k, dim), np.float)
103         for i in range(N):
104             centroid_sums[assignments[i]] += X[i]
105         for j in range(k):
106             centroids[j] = centroid_sums[j] / (np.sum(assignments == j))
107         prev_assignments = assignments.copy()
108         old_D = D.copy()
109 
110 
def _parse_args(argv=None):
    """Parse the command-line options (``argv`` excludes the program name; None means sys.argv)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-filelist', default='F://BaiduNetdiskDownload//trainall_name.txt',
                        help='path to filelist\n')
    parser.add_argument('-output_dir', default='F://BaiduNetdiskDownload//generated_anchors//anchors//', type=str,
                        help='Output anchor directory\n')
    parser.add_argument('-num_clusters', default=6, type=int,
                        help='number of clusters\n')
    return parser.parse_args(argv)


def _label_path(image_path):
    """Map an image path from the file list to its label (.txt) path."""
    # Labels are found by replacing 'Images' with 'labels' in the image path;
    # adjust for other dataset layouts (e.g. 'img1' or 'JPEGImages').
    label_path = image_path.replace('Images', 'labels')
    root, ext = os.path.splitext(label_path)
    # Only the real extension is swapped (case-insensitively), not '.jpg' elsewhere in the path.
    if ext.lower() in ('.jpg', '.jpeg', '.png'):
        label_path = root + '.txt'
    return label_path


def _load_annotation_dims(filelist):
    """Return an ``(n, 2)`` array of the (w, h) of every box in the labels of ``filelist``."""
    with open(filelist) as f:
        image_paths = [line.strip() for line in f if line.strip()]
    annotation_dims = []
    for image_path in image_paths:
        label_path = _label_path(image_path)
        print(label_path)
        with open(label_path) as label_file:
            for row in label_file:
                fields = row.split()
                if not fields:  # tolerate blank lines, e.g. a trailing empty line
                    continue
                # Fields 3 and 4 hold the normalized width and height,
                # e.g. "0 0.83984 0.40700 0.17188 0.47218".
                w, h = fields[3:5]
                annotation_dims.append((float(w), float(h)))
    return np.array(annotation_dims)


def main(argv):
    """Cluster the labelled box sizes into anchors and write them to ``output_dir``.

    With ``-num_clusters 0`` every k from 1 to 10 is run, each written to its
    own ``anchors<k>.txt``; otherwise only the requested k is computed.
    """
    args = _parse_args(argv[1:] if argv else None)
    # makedirs also creates missing parent directories, where os.mkdir would fail.
    os.makedirs(args.output_dir, exist_ok=True)
    annotation_dims = _load_annotation_dims(args.filelist)
    eps = 0.005
    cluster_counts = range(1, 11) if args.num_clusters == 0 else [args.num_clusters]
    for num_clusters in cluster_counts:
        anchor_file = join(args.output_dir, 'anchors%d.txt' % (num_clusters))  # result file
        n_boxes = annotation_dims.shape[0]
        if n_boxes < num_clusters:
            raise ValueError('need at least %d boxes for %d clusters, found %d'
                             % (num_clusters, num_clusters, n_boxes))
        # Seed with distinct random boxes: duplicate seeds would start two identical
        # centroids and leave one cluster empty.
        indices = random.sample(range(n_boxes), num_clusters)
        centroids = annotation_dims[indices]
        kmeans(annotation_dims, centroids, eps, anchor_file)
        print('centroids.shape', centroids.shape)
159 
160 
if __name__ == "__main__":
    # Script entry point: hand the raw command line to main().
    main(argv=sys.argv)

生成YOLOv3的anchor时需要注意下面这一行:YOLOv3的anchor以输入图片的像素为单位,所以只需把归一化后的宽高乘以cfg中设置的输入宽高,不需要像YOLOv2那样再除以32(下采样步长):

anchors[i][0] *= width_in_cfg_file  # / 32. YOLOv3不用除以32 # 归一化后的宽高乘以预设的图片宽高

最终生成的结果如下(共6个anchor,按宽度从小到大排序;第二行是所有标注框与其最匹配anchor的平均IoU):

7.90,21.48, 18.72,61.61, 34.67,138.55, 65.49,251.30, 104.70,64.11, 144.33,434.60
0.582349

可以看出,大部分anchor的宽高比约为1:3。考虑到我使用的是行人检测数据集,这个比例比较合理;但第5组数据(104.70, 64.11)不符合这一宽高比。

 

posted @ 2020-08-02 12:31  DJames23  阅读(1375)  评论(0编辑  收藏  举报