# ruijiege
#
# 博客园 :: 首页 :: 博问 :: 闪存 :: 新随笔 :: 联系 :: 订阅 :: 管理
import json
import numpy as np
import torch
from scipy.cluster.vq import kmeans2,kmeans
from sklearn.cluster import k_means
import random

# Load the label table. From the indexing below, each row is assumed to hold the
# normalized box width/height in columns 3, 4 and the source image width/height
# in columns 5, 6 -- TODO confirm against whatever writes info.json.
with open("info.json", "r", encoding="utf-8") as f:
    lables = np.array(json.load(f))

# Target training resolution: the longer image side is rescaled to this size.
image_size = 640
# Original image width/height of every label row.
lable_image_width_height = lables[:, [5, 6]]

# Aspect-preserving rescale of each image so that its longer side equals image_size.
lable_image_width_height_max_line = lable_image_width_height.max(axis=1, keepdims=True)
std_image_width_height = image_size * lable_image_width_height / lable_image_width_height_max_line

# Box width/height in pixels at the rescaled resolution.
std_box_wh = std_image_width_height * lables[:, [3, 4]]
# Drop tiny boxes: keep a box only if at least one side is >= 2 px.
keep_index = (std_box_wh >= 2).any(axis=1)
keep_wh = std_box_wh[keep_index]

# 9 anchors = 3 stages (stride 8, 16, 32) x 3 aspect ratios (w/h = 0.5, 1.0, 2.0).
# Whiten (divide by the per-axis std) so k-means weighs width and height equally,
# then cluster into num_anchor centroids.
num_anchor = 9
keep_wh_std = keep_wh.std(axis=0)
whiten_wh = keep_wh / keep_wh_std
k, _ = kmeans(whiten_wh, num_anchor)
# scipy's kmeans removes centroids that end up with no members, so it may return
# fewer codes than requested; surface that instead of silently continuing.
if len(k) != num_anchor:
    print(f"warning: kmeans returned {len(k)} anchors, expected {num_anchor}")

# Undo the whitening and sort anchors by area (small -> large).
new_anchor = k * keep_wh_std
new_anchor = new_anchor[new_anchor.prod(axis=1).argsort()]

keep_wh = torch.FloatTensor(keep_wh)
new_anchor = torch.FloatTensor(new_anchor)

# A box matches an anchor when its width ratio and height ratio both lie
# strictly within (1/anchor_t, anchor_t).
anchor_t = 4
# Broadcast keep_wh [N, 2] against new_anchor [K, 2] -> ratio [N, K, 2].
ratio = keep_wh[:, None] / new_anchor[None]
box_div_anchor = ratio
anchor_div_box = 1 / ratio
# Worse of the width/height ratios (folded to >= 1) for every (box, anchor) pair.
max_ratio = torch.max(box_div_anchor, anchor_div_box).max(2)[0]
# match_cond [N, K]
match_cond = max_ratio < anchor_t

# Best possible recall of the k-means anchors: the fraction of boxes matched by at
# least one anchor. Stored under its own name because a bare `bpr`/`fitness`
# would be rebound by the function definitions that follow, losing the value.
initial_bpr = match_cond.any(1).float().mean()

# Fitness: each box scores the folded ratio (<= 1) of its best anchor; boxes whose
# best anchor does not match contribute 0. The result is the mean over boxes.
min_ratio = torch.min(box_div_anchor, anchor_div_box).min(2)[0]
min_ratio = min_ratio.max(1)[0]
initial_fitness = (min_ratio * (min_ratio > 1 / anchor_t)).float().mean()
print(f"kmeans anchors: bpr {initial_bpr:.5f}, fitness {initial_fitness:.5f}")

def fitness(box, anchor, thr=None):
    """Return the mean anchor-fit score of ``box`` against ``anchor``.

    For every (box, anchor) pair, the per-side ratios ``box / anchor`` are folded
    into ``min(r, 1/r)`` (1.0 means an exact match), and the worse of the two sides
    is kept. Each box is scored by its best anchor. A box whose best score is not
    strictly above ``1 / thr`` contributes 0. The result is the mean over all boxes.

    Args:
        box: 2-D tensor of box sizes, one (width, height) row per box, [N, 2].
        anchor: 2-D tensor of anchor sizes, one (width, height) row per anchor, [K, 2].
        thr: maximum allowed side ratio. Defaults to the module-level ``anchor_t``
            so existing two-argument calls behave exactly as before.

    Returns:
        0-dim float tensor.
    """
    if thr is None:
        thr = anchor_t
    ratio = box[:, None] / anchor[None]  # [N, K, 2]
    # Folded ratio is <= 1 and penalizes too-large and too-small sides alike;
    # min over dim 2 keeps the worse side of each (box, anchor) pair -> [N, K].
    worst_side = torch.min(ratio, 1 / ratio).min(2)[0]
    best = worst_side.max(1)[0]  # best anchor per box -> [N]
    return (best * (best > 1 / thr)).float().mean()

def bpr(box, anchor, thr=None):
    """Return the best possible recall of ``anchor`` for ``box``.

    This is the fraction of boxes that have at least one anchor whose width ratio
    and height ratio both lie strictly within ``(1/thr, thr)``.

    Args:
        box: 2-D tensor of box sizes, one (width, height) row per box, [N, 2].
        anchor: 2-D tensor of anchor sizes, one (width, height) row per anchor, [K, 2].
        thr: maximum allowed side ratio. Defaults to the module-level ``anchor_t``
            so existing two-argument calls behave exactly as before.

    Returns:
        0-dim float tensor in [0, 1].
    """
    if thr is None:
        thr = anchor_t
    ratio = box[:, None] / anchor[None]  # [N, K, 2]
    # Worse side of each (box, anchor) pair, folded into min(r, 1/r) -> [N, K].
    worst_side = torch.min(ratio, 1 / ratio).min(2)[0]
    best = worst_side.max(1)[0]  # best anchor per box -> [N]
    return (best > 1 / thr).float().mean()

# Refine the k-means anchors with a simple genetic search: mutate all anchors
# multiplicatively and keep the mutation only if it improves fitness.

iter_count = 1000
anchor_shape = new_anchor.shape
# Probability that an individual element is mutated, and the mutation scale.
mutate_prob = 0.9
mutate_sigma = 0.1

curren_fitness = fitness(keep_wh, new_anchor)
curren_bpr = bpr(keep_wh, new_anchor)
print(f"curren_bpr {curren_bpr:.5f}")

for _ in range(iter_count):
    mutate_coeff = torch.ones_like(new_anchor)
    # Resample until at least one coefficient differs from 1, so every
    # iteration actually tries a different anchor set.
    while (mutate_coeff == 1).all():
        mutate_range = (torch.rand(anchor_shape) < mutate_prob) * np.random.random() * torch.randn(anchor_shape)
        mutate_coeff = (mutate_range * mutate_sigma + 1).clamp(0.3, 3.0)
    # Keep every anchor side at least 2 px, matching the box filter used earlier.
    mutate_anchor = (new_anchor * mutate_coeff).clamp(2.0)
    mutate_fitness = fitness(keep_wh, mutate_anchor)
    if mutate_fitness > curren_fitness:
        curren_fitness = mutate_fitness
        new_anchor = mutate_anchor
        # Report only accepted improvements instead of all iter_count trials.
        print(f"mutate_fitness: {mutate_fitness:.5f}")

# Mutation can reorder anchor areas. Restore the small -> large area order that
# was established before the search.
new_anchor = new_anchor[new_anchor.prod(dim=1).argsort()]
print(f"final bpr {bpr(keep_wh, new_anchor):.5f}, fitness {curren_fitness:.5f}")

# 默认共有9个anchor:3个stage(stride 8、16、32),每个stage有3种宽高比(width/height = 0.5、1.0、2.0)。
#
# 如果默认anchor的bpr(best possible recall,即至少能被一个anchor匹配的框所占比例)< 0.9,则使用上面的k-means聚类和遗传算法重新选出更合适的anchor。

# posted on 2022-11-14 11:40  哦哟这个怎么搞  阅读(81)  评论(0)