psenet的eval_ctw1500.py解析
评价模型指标,自带的eval_ctw1500.py是根据两两框重叠面积大于0.5算正例tp。其中用到了import Polygon as plg 模块,可以方便的处理多边形的重叠计算面积
1.pip install Polygon2
2.cover = set() 这里用到了set,之前在c++里面也看到过这个的,一查发现差不多功能,就是没有重复的无序的
file_util.py
import os
def read_dir(root):
file_path_list = []
for file_path, dirs, files in os.walk(root):
for file in files:
file_path_list.append(os.path.join(file_path, file).replace('\\', '/'))
file_path_list.sort()
return file_path_list
def read_file(file_path):
file_object = open(file_path, 'r')
file_content = file_object.read()
file_object.close()
return file_content
def write_file(file_path, file_content):
if file_path.find('/') != -1:
father_dir = '/'.join(file_path.split('/')[0:-1])
if not os.path.exists(father_dir):
os.makedirs(father_dir)
file_object = open(file_path, 'w')
file_object.write(file_content)
file_object.close()
def write_file_not_cover(file_path, file_content):
father_dir = '/'.join(file_path.split('/')[0:-1])
if not os.path.exists(father_dir):
os.makedirs(father_dir)
file_object = open(file_path, 'a')
file_object.write(file_content)
file_object.close()
eval_ctw1500.py
import file_util
import Polygon as plg
import numpy as np
pred_root = '../../outputs/submit_ctw1500/'
gt_root = '../../data/CTW1500/test/text_label_curve/'
def get_pred(path):
lines = file_util.read_file(path).split('\n')
bboxes = []
for line in lines:
if line == '':
continue
bbox = line.split(',')
if len(bbox) % 2 == 1:
print path
bbox = [(int)(x) for x in bbox]
bboxes.append(bbox)
return bboxes
def get_gt(path):
print("###############################path=%s"%(path))
lines = file_util.read_file(path).split('\n')
bboxes = []
for line in lines:
if line == '':
continue
# line = util.str.remove_all(line, '\xef\xbb\xbf')
# gt = util.str.split(line, ',')
gt = line.split(',')
x1 = np.int(gt[0])
y1 = np.int(gt[1])
bbox = [np.int(gt[i]) for i in range(4, 32)]
bbox = np.asarray(bbox) + ([x1, y1] * 14)
bboxes.append(bbox)
return bboxes
def get_union(pD,pG):
areaA = pD.area();
areaB = pG.area();
return areaA + areaB - get_intersection(pD, pG);
def get_intersection(pD,pG):
pInt = pD & pG
if len(pInt) == 0:
return 0
return pInt.area()
if __name__ == '__main__':
th = 0.5
pred_list = file_util.read_dir(pred_root)
tp, fp, npos = 0, 0, 0
for pred_path in pred_list:
preds = get_pred(pred_path)
gt_path = gt_root + pred_path.split('/')[-1]
gts = get_gt(gt_path)
npos += len(gts)
cover = set()
for pred_id, pred in enumerate(preds):
pred = np.array(pred)
pred = pred.reshape(pred.shape[0] / 2, 2)
# if pred.shape[0] <= 2:
# continue
pred_p = plg.Polygon(pred)
flag = False
for gt_id, gt in enumerate(gts):
gt = np.array(gt)
gt = gt.reshape(gt.shape[0] / 2, 2)
gt_p = plg.Polygon(gt)
union = get_union(pred_p, gt_p)
inter = get_intersection(pred_p, gt_p)
if inter * 1.0 / union >= th:
if gt_id not in cover:
flag = True
cover.add(gt_id)
if flag:
tp += 1.0
else:
fp += 1.0
print tp, fp, npos
precision = tp / (tp + fp)
recall = tp / npos
hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)
print 'p: %.4f, r: %.4f, f: %.4f'%(precision, recall, hmean)
好记性不如烂键盘---点滴、积累、进步!