# k-means算法求解anchors (针对YOLOv3)
# 文字内容以后再补充
import numpy as np
# 定义Box类,描述bounding box的坐标
class Box():
def __init__(self, x, y, w, h):
self.x = x
self.y = y
self.w = w
self.h = h
def box_iou(a, b):
'''
# a和b都是Box类型实例
# 返回值area是box a 和box b 的交集面积
'''
a_x1 = a.x-a.w/2
a_y1 = a.y - a.h / 2
a_x2 = a.x+a.w/2
a_y2 = a.y + a.h / 2
b_x1 = b.x-b.w/2
b_y1 = b.y - b.h / 2
b_x2 = b.x+b.w/2
b_y2 = b.y + b.h / 2
box_x1 = max(a_x1,b_x1)
box_y1 = max(a_y1, b_y1)
box_x2 = min(a_x2,b_x2)
box_y2 = min(a_y2, b_y2)
box_w = box_x2-box_x1
box_h = box_y2 - box_y1
if box_w < 0 or box_h < 0:
area = 0
else:
area = box_w * box_h
box_intersection=area
box_union = a.w * a.h + b.w * b.h-box_intersection
iou = box_intersection/box_union
return iou
# 使用k-means ++ 初始化 centroids,减少随机初始化的centroids对最终结果的影响
def init_centroids(boxes, n_anchors):
'''
随机选择一个box作为
:param boxes: 是所有bounding boxes的Box对象列表
:param n_anchors: n_anchors是k-means的k值
:return: 返回值centroids 是初始化的n_anchors个centroid
'''
centroids = []
boxes_num = len(boxes)
centroid_index = np.random.choice(boxes_num, 1) # 在boxes_num=55 中产生一个数23
centroids.append(boxes[centroid_index])
print(centroids[0].w, centroids[0].h)
for centroid_index in range(0, n_anchors-1):
sum_distance = 0
distance_list = []
cur_sum = 0
for box in boxes:
min_distance = 1
for centroid_i, centroid in enumerate(centroids):
distance = (1 - box_iou(box, centroid))
if distance < min_distance:
min_distance = distance
sum_distance += min_distance
distance_list.append(min_distance)
distance_thresh = sum_distance*np.random.random()
for i in range(0, boxes_num):
cur_sum += distance_list[i]
if cur_sum > distance_thresh:
centroids.append(boxes[i])
print(boxes[i].w, boxes[i].h)
break
return centroids
# 进行 k-means 计算新的centroids
def do_kmeans(n_anchors, boxes, centroids):
'''
:param n_anchors: 是k-means的k值
:param boxes: 是所有bounding boxes的Box对象列表
:param centroids: 是所有簇的中心
:return: # 返回值new_centroids 是计算出的新簇中心
# 返回值groups是n_anchors个簇包含的boxes的列表
# 返回值loss是所有box距离所属的最近的centroid的距离的和
'''
loss = 0
groups = []
new_centroids = []
for i in range(n_anchors):
groups.append([]) # [[], [], [], []]
new_centroids.append(Box(0, 0, 0, 0))
# 以上代码建立初始化
for box in boxes:
min_distance = 1
group_index = 0
for centroid_index, centroid in enumerate(centroids):
# 这个循环实际是在找box与哪个centroidsiou最小,最接近
distance = (1 - box_iou(box, centroid))
if distance < min_distance:
min_distance = distance
group_index = centroid_index
groups[group_index].append(box) # 将其保留对应的族中
loss += min_distance
new_centroids[group_index].w += box.w # 累加对应的族中的w
new_centroids[group_index].h += box.h
for i in range(n_anchors): # 得到新的族中的w与h
new_centroids[i].w /= len(groups[i])
new_centroids[i].h /= len(groups[i])
return new_centroids, groups, loss
def init_all_value(use_init_centroids=1, n_anchors=9):
# 构建初始化族中心
if use_init_centroids:
centroids = init_centroids(boxes, n_anchors)
else:
centroid_indices = np.random.choice(len(boxes), n_anchors)
centroids = []
for centroid_index in centroid_indices:
centroids.append(boxes[centroid_index])
# 构建初始化 groups 保存对应族的box类
centroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids)
return centroids, groups, old_loss
if __name__=='__main__':
# 构建boxes
boxes=[]
boxes.append(Box(4,5,6,7)) # 根据实际情况自己填写
num_anchor = 9 # 产生族中心点是多少
# 构建停止条件
num_iterations=2000
loss_stop = 1e-6
centroids, groups, old_loss = init_all_value(1, num_anchor)
# 循环找到族中最好的w与h
iterations = 1
while (True):
centroids, groups, loss = do_kmeans(num_anchor, boxes, centroids)
iterations = iterations + 1
print("loss = %f" % loss)
if abs(old_loss - loss) < loss_stop or iterations > num_iterations:
break
old_loss = loss
# 打印最终结果
for centroid in centroids:
print("k-means result:\n")
print(centroid.w, centroid.h)