"""Estimate anchor boxes for a YOLO-style detector from labelled boxes.

Pipeline (run only when executed as a script):

1. Load label rows from ``info.json`` and rescale every box to a
   ``image_size``-pixel long side.
2. Cluster the (std-normalised) box width/height pairs with k-means to get
   ``num_anchor`` initial anchors, sorted by area.
3. Refine the anchors with random multiplicative mutations, keeping a
   mutation only when it improves :func:`fitness`.
4. Print the best-possible-recall (:func:`bpr`) before and after refinement.
"""
import json
import random
import warnings

import numpy as np
import torch
from scipy.cluster.vq import kmeans2, kmeans
from sklearn.cluster import k_means

# Long side (in pixels) every image is rescaled to before measuring boxes.
image_size = 640
# Number of anchors to cluster. Original author's note: 9 corresponds to
# stages 8/16/32 combined with width/height ratios 0.5/1.0/2.0.
num_anchor = 9
# A box "matches" an anchor when neither side differs by more than this factor.
anchor_t = 4
# Number of mutation rounds in the refinement loop.
iter_count = 1000


def _best_ratio(box, anchor):
    """Return, for every box, its agreement with the best-fitting anchor.

    For each (box, anchor) pair the per-side ratio is folded into (0, 1] via
    ``min(r, 1 / r)``; the worse of the two sides scores the pair, and the
    best pair score over all anchors is returned.

    Args:
        box: tensor indexed as ``[N, 2]`` (width, height per box).
        anchor: tensor indexed as ``[K, 2]`` (width, height per anchor).

    Returns:
        Tensor of length ``N`` with values in (0, 1]; 1 means a perfect fit.
    """
    # Broadcast [N, 1, 2] / [1, K, 2] -> [N, K, 2].
    ratio = box[:, None] / anchor[None]
    worst_side = torch.min(ratio, 1 / ratio).min(2)[0]  # [N, K]
    return worst_side.max(1)[0]  # [N]


def fitness(box, anchor):
    """Return the mean best-anchor ratio, counting unmatched boxes as zero.

    A box is matched when its best ratio exceeds ``1 / anchor_t``.
    """
    best = _best_ratio(box, anchor)
    return (best * (best > 1 / anchor_t)).float().mean()


def bpr(box, anchor):
    """Return the fraction of boxes matched by at least one anchor.

    A box is matched when its best ratio exceeds ``1 / anchor_t``.
    """
    return (_best_ratio(box, anchor) > 1 / anchor_t).float().mean()


def _load_box_wh(path):
    """Load label rows from *path* and return box sizes at ``image_size`` scale.

    NOTE(review): columns 5 and 6 are presumably the original image width and
    height, and columns 3 and 4 the box width and height normalised to the
    image -- inferred from the original variable names; confirm against the
    code that writes the JSON file.

    Boxes whose width and height are both below 2 pixels are dropped.
    """
    with open(path, "r") as f:
        labels = np.array(json.load(f))

    image_wh = labels[:, [5, 6]]
    # Scale so the longer image side becomes image_size, keeping aspect ratio.
    longest_side = image_wh.max(axis=1, keepdims=True)
    scaled_image_wh = image_size * image_wh / longest_side

    box_wh = scaled_image_wh * labels[:, [3, 4]]
    return box_wh[(box_wh >= 2).any(axis=1)]


def _kmeans_anchors(box_wh, n):
    """Cluster *box_wh* into (up to) *n* anchors, sorted by ascending area.

    Width and height are divided by their standard deviation before
    clustering so both dimensions contribute equally, then scaled back.

    Raises:
        ValueError: if *box_wh* contains no boxes.
    """
    if len(box_wh) == 0:
        raise ValueError("no boxes left after filtering; cannot cluster anchors")

    scale = box_wh.std(axis=0)
    centroids, _ = kmeans(box_wh / scale, n)
    if len(centroids) != n:
        # scipy drops centroids that end up with no members, so fewer than
        # n rows may come back; continue, but make it visible.
        warnings.warn(
            f"kmeans returned {len(centroids)} of {n} requested anchors"
        )

    anchors = centroids * scale
    return anchors[anchors.prod(axis=1).argsort()]


def _evolve_anchors(box_wh, anchors, generations):
    """Refine *anchors* by random mutation, keeping only improvements.

    Each round multiplies the anchors by random coefficients (each entry is
    mutated with probability 0.9, coefficients clamped to [0.3, 3.0]),
    clamps the result to at least 2 pixels, and accepts it when
    :func:`fitness` increases.
    """
    best_fitness = fitness(box_wh, anchors)
    shape = anchors.shape
    for _ in range(generations):
        coeff = torch.ones_like(anchors)
        # Redraw until at least one entry actually changes.
        while (coeff == 1).all():
            noise = (torch.rand(shape) < 0.9) * np.random.random() * torch.randn(shape)
            coeff = (noise * 0.1 + 1).clamp(0.3, 3.0)

        candidate = (anchors * coeff).clamp(2.0)
        candidate_fitness = fitness(box_wh, candidate)
        print(f"mutate_fitness: {candidate_fitness}")
        if candidate_fitness > best_fitness:
            best_fitness = candidate_fitness
            anchors = candidate
    return anchors


def main():
    """Run the full anchor estimation on ``info.json`` and return the anchors."""
    box_wh = _load_box_wh("info.json")
    initial_anchors = _kmeans_anchors(box_wh, num_anchor)

    keep_wh = torch.tensor(box_wh, dtype=torch.float32)
    new_anchor = torch.tensor(initial_anchors, dtype=torch.float32)

    curren_bpr = bpr(keep_wh, new_anchor)
    print(f"curren_bpr {curren_bpr:.5f}")

    # Genetic-style refinement of the k-means result.
    new_anchor = _evolve_anchors(keep_wh, new_anchor, iter_count)
    print(bpr(keep_wh, new_anchor))
    return new_anchor


if __name__ == "__main__":
    main()
# By default there are three anchor choices (stages 8, 16, 32; width/height ratios 0.5, 1.0, 2.0).
# If bpr < 0.9, the clustering algorithm and genetic mutation above are used to select better anchors.