如何将 MMRotate 的识别结果转换为 DOTA 和 FAIR1M 格式
问题来源:在使用 MMRotate 的过程中,需要把推断(识别)结果导出保存,结果发现缺少相关功能:
From the demo I know that show_result_pyplot can plot the inference results. I would like to ask how to convert inference results to DOTA format. Is there a related function, or do I need to handle the result directly?
thanks!
Hi @jsxyhelu, we did not provide the corresponding script to the user. You need to convert the results to DOTA format by yourself. Welcome to submit your script to help more people.
thanks!
Hi @jsxyhelu, we did not provide the corresponding script to the user. You need to convert the results to DOTA format by yourself. Welcome to submit your script to help more people.
那么就自己来设计实现相关功能:
一、数据格式:
MMRotate 的输出中,每一行对应一个检测框,
依次为:中心点坐标 x、y,宽 w,高 h,旋转角 theta(弧度),置信度 score。
目标格式为 DOTA 格式,采用 txt 文件存放,
每行对应一个标注框:x1 y1 x2 y2 x3 y3 x4 y4 classname difficult,即四个角点坐标、类别名和难度标记,字段之间用空格分隔。注意这里的坐标没有做归一化处理。
二、批量处理和保存
首先将推断结果 result 保存下来:
import os
import numpy as np
# Run single-image inference on every image in src_label_root and save the
# detections of result[0] for each image to a same-named .txt file in
# dst_label_root (one comma-separated row per box: x, y, w, h, theta, score).
src_label_root = '/root/mmrotate/demo/ssdd_tiny/images/'
dst_label_root = '/root/mmrotate/demo/ssdd_tiny/dst/'
# Portable replacement for the notebook-only `!mkdir` shell escape; unlike
# mkdir it does not fail when the directory already exists.
os.makedirs(dst_label_root, exist_ok=True)
model.cfg = cfg
for src_label_name in os.listdir(src_label_root):
    src_label_path = os.path.join(src_label_root, src_label_name)  # input image path
    dst_label_path = os.path.join(
        dst_label_root, os.path.splitext(src_label_name)[0] + ".txt")
    img = mmcv.imread(src_label_path)
    result = inference_detector(model, img)
    # Only result[0] (the first class) is saved.
    # NOTE(review): fine for this single-class SSDD demo; multi-class models
    # would need every entry of result saved together with its class index.
    np.savetxt(dst_label_path, result[0], delimiter=',')
    print(dst_label_path)
而后进行格式转换。以下代码针对单通道图片(本例只有一个类别,即 class 0):
def rota(x, y, w, h, a):
    """Convert a rotated box (centre, size, angle) to its four corner points.

    The corners of a w x h box centred at (x, y) are rotated by ``a`` around
    the centre using the 2-D rotation
    ``(dx, dy) -> (dx*cos a - dy*sin a, dx*sin a + dy*cos a)``.

    Args:
        x, y: Centre of the box.
        w, h: Width and height of the box.
        a: Rotation angle in radians (passed directly to math.cos/math.sin).

    Returns:
        tuple: (px1, py1, px2, py2, px3, py3, px4, py4), the rotated corners
        that were top-left, top-right, bottom-right and bottom-left before
        rotation.
    """
    import math  # local import: none of the surrounding snippets import math

    cos_a = math.cos(a)
    sin_a = math.sin(a)
    half_w = w / 2
    half_h = h / 2
    corners = []
    # Offsets from the centre of TL, TR, BR, BL before rotation.
    for dx, dy in ((-half_w, -half_h), (half_w, -half_h),
                   (half_w, half_h), (-half_w, half_h)):
        corners.append(dx * cos_a - dy * sin_a + x)
        corners.append(dx * sin_a + dy * cos_a + y)
    return tuple(corners)
def mmrotate2dota(src_img_root, src_label_root, dst_label_root, class_map, score_thr=0.3):
    """Convert saved MMRotate detections to DOTA-style polygon label files.

    Every ``.txt`` file in ``src_label_root`` holds one comma-separated
    detection per line: ``x, y, w, h, theta, score``, optionally followed by
    a 7th column with the class index (as written by the batch inference
    script). Each detection with ``score >= score_thr`` is written to a
    same-named file in ``dst_label_root`` as
    ``x1 y1 x2 y2 x3 y3 x4 y4 classname difficult``.

    Args:
        src_img_root: Unused; kept for backward compatibility.
        src_label_root: Directory containing the saved detection .txt files.
        dst_label_root: Output directory (created if missing).
        class_map: Maps a class index as a string (e.g. '0') to a class name.
        score_thr: Minimum score for a detection to be written.
    """
    os.makedirs(dst_label_root, exist_ok=True)
    for src_label_name in os.listdir(src_label_root):
        if not src_label_name.endswith('.txt'):
            continue  # only detection .txt files are converted
        src_label_path = os.path.join(src_label_root, src_label_name)  # input path
        dst_label_path = os.path.join(dst_label_root, src_label_name)  # output path
        dst_label_list = []
        with open(src_label_path, 'r') as fr:
            for line in fr:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines instead of crashing in float('')
                fields = line.split(",")
                x, y, w, h, a, score = (float(v) for v in fields[:6])
                if score < score_thr:
                    continue
                # 6-column files carry no class index (single class) -> '0'.
                class_key = str(int(float(fields[6]))) if len(fields) > 6 else '0'
                corners = rota(x, y, w, h, a)
                # Target format: x1 y1 x2 y2 x3 y3 x4 y4 classname difficult,
                # every field space-separated (including classname and flag).
                # NOTE(review): the difficult flag is fixed to 1 as before; confirm
                # this is what the evaluator expects for detection results.
                dstline = (" ".join(str(c) for c in corners)
                           + " " + str(class_map[class_key]) + " 1")
                dst_label_list.append(dstline)
        with open(dst_label_path, 'w') as fw:
            fw.writelines(out_line + '\n' for out_line in dst_label_list)
        print(dst_label_path)
    print('convert done')
def rota(x, y, w, h, a):
    """Convert a rotated box (centre, size, angle) to its four corner points.

    The corners of a w x h box centred at (x, y) are rotated by ``a`` around
    the centre using the 2-D rotation
    ``(dx, dy) -> (dx*cos a - dy*sin a, dx*sin a + dy*cos a)``.

    Args:
        x, y: Centre of the box.
        w, h: Width and height of the box.
        a: Rotation angle in radians (passed directly to math.cos/math.sin).

    Returns:
        tuple: (px1, py1, px2, py2, px3, py3, px4, py4), the rotated corners
        that were top-left, top-right, bottom-right and bottom-left before
        rotation.
    """
    import math  # local import: none of the surrounding snippets import math

    cos_a = math.cos(a)
    sin_a = math.sin(a)
    half_w = w / 2
    half_h = h / 2
    corners = []
    # Offsets from the centre of TL, TR, BR, BL before rotation.
    for dx, dy in ((-half_w, -half_h), (half_w, -half_h),
                   (half_w, half_h), (-half_w, half_h)):
        corners.append(dx * cos_a - dy * sin_a + x)
        corners.append(dx * sin_a + dy * cos_a + y)
    return tuple(corners)
def mmrotate2dota(src_img_root, src_label_root, dst_label_root, class_map, score_thr=0.3):
    """Convert saved MMRotate detections to DOTA-style polygon label files.

    Every ``.txt`` file in ``src_label_root`` holds one comma-separated
    detection per line: ``x, y, w, h, theta, score``, optionally followed by
    a 7th column with the class index (as written by the batch inference
    script). Each detection with ``score >= score_thr`` is written to a
    same-named file in ``dst_label_root`` as
    ``x1 y1 x2 y2 x3 y3 x4 y4 classname difficult``.

    Args:
        src_img_root: Unused; kept for backward compatibility.
        src_label_root: Directory containing the saved detection .txt files.
        dst_label_root: Output directory (created if missing).
        class_map: Maps a class index as a string (e.g. '0') to a class name.
        score_thr: Minimum score for a detection to be written.
    """
    os.makedirs(dst_label_root, exist_ok=True)
    for src_label_name in os.listdir(src_label_root):
        if not src_label_name.endswith('.txt'):
            continue  # only detection .txt files are converted
        src_label_path = os.path.join(src_label_root, src_label_name)  # input path
        dst_label_path = os.path.join(dst_label_root, src_label_name)  # output path
        dst_label_list = []
        with open(src_label_path, 'r') as fr:
            for line in fr:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines instead of crashing in float('')
                fields = line.split(",")
                x, y, w, h, a, score = (float(v) for v in fields[:6])
                if score < score_thr:
                    continue
                # 6-column files carry no class index (single class) -> '0'.
                class_key = str(int(float(fields[6]))) if len(fields) > 6 else '0'
                corners = rota(x, y, w, h, a)
                # Target format: x1 y1 x2 y2 x3 y3 x4 y4 classname difficult,
                # every field space-separated (including classname and flag).
                # NOTE(review): the difficult flag is fixed to 1 as before; confirm
                # this is what the evaluator expects for detection results.
                dstline = (" ".join(str(c) for c in corners)
                           + " " + str(class_map[class_key]) + " 1")
                dst_label_list.append(dstline)
        with open(dst_label_path, 'w') as fw:
            fw.writelines(out_line + '\n' for out_line in dst_label_list)
        print(dst_label_path)
    print('convert done')
得到初步的对比结果,目视检查是正确的。
使用 DOTA 官方工具(DOTA_devkit)进行绘制验证。
具体过程请
查看 https://www.kaggle.com/code/jsxyhelu2019/ddd-mmrotate-result2dota
三、获得批量处理结果
当前的处理只针对一种类别;处理多类别的批量数据时情况有所不同。
而且转换的过程中存在错误,需要进行修正。
参考现有的示例,可以加载已有的模型权重文件(pt)并执行推断,得到结果。
推断结果的组织方式如下:
一共有 37 个 array(每个类别对应一个),每个 array 中是该类别推测出来的所有目标框。
因此在写入文件时,需要把类别编号一并记录下来。
此外,对于大尺寸图像,推断时需要使用按切块(patch)推断的接口:
from mmrotate.apis import inference_detector_by_patches

# Split the demo image into 1024-px patches with an 824-px step at scale 1.0,
# run the detector on every patch and merge the results with IoU threshold 0.1.
img = 'demo/dota_demo.jpg'
result = inference_detector_by_patches(
    model, img,
    sizes=[1024], steps=[824], ratios=[1.0], merge_iou_thr=0.1)
def inference_detector_by_patches(model,
                                  img,
                                  sizes,
                                  steps,
                                  ratios,
                                  merge_iou_thr,
                                  bs=1):
    """Run patch-wise inference with the detector.

    Split huge image(s) into patches and run the detector on each patch.
    Finally, merge the patch results back onto the full image with NMS.

    Args:
        model (nn.Module): The loaded detector.
        img (str | ndarray): Either an image file path or a loaded image.
        sizes (list): The sizes of the patches.
        steps (list): The steps between two neighbouring patches.
        ratios (list): Image resizing ratios for multi-scale detection.
        merge_iou_thr (float): IoU threshold used when merging results.
        bs (int): Batch size; must be greater than or equal to 1.

    Returns:
        list[np.ndarray]: Detection results.
    """
    # NOTE(review): only the signature and docstring of
    # mmrotate.apis.inference_detector_by_patches are reproduced here; the
    # body is omitted. If executed as-is, this definition shadows the imported
    # library function and every call returns None.
所以最终,对单张图片的处理代码如下:
# Run patch-based inference on a single image and collect every detection,
# tagging each row with the index of the class it belongs to.
from mmrotate.apis import inference_detector_by_patches

img = '/home/helu/workstation/Fair1m/fair1M_jpg_train_split_1280_200/images/1__1__0___0.jpg'
result = inference_detector_by_patches(model, img, [1024], [824], [1.0], 0.1)

# result holds one array per class. np.append flattens each detection row and
# adds the class index as a trailing float32 column. Empty class arrays
# contribute nothing, because iterating them yields no rows.
dst = [
    np.append(det_row, np.float32(cls_idx))
    for cls_idx, cls_dets in enumerate(result)
    for det_row in cls_dets
]
print(dst)
# To visualise: show_result_pyplot(model, img, result, score_thr=0.3)
批量处理所有图片,保存推断结果(之后再转换为 DOTA 格式):
# Batch inference: run patch-based detection on every image in
# test_image_root and save all detections of that image to a same-named .txt
# file in test_result_root, one comma-separated row per box:
# x, y, w, h, theta, score, class_index.
from mmrotate.apis import inference_detector_by_patches

test_image_root = '/home/helu/workstation/Fair1m/fair1M_jpg_test_tiny/images/'
test_result_root = '/home/helu/workstation/Fair1m/fair1M_jpg_test_tiny/labelTxt/'
# np.savetxt does not create missing directories, so create the output
# directory up front.
os.makedirs(test_result_root, exist_ok=True)

for test_image_name in os.listdir(test_image_root):
    test_image_path = os.path.join(test_image_root, test_image_name)  # input path
    dst_label_path = os.path.join(
        test_result_root, os.path.splitext(test_image_name)[0] + ".txt")
    img = mmcv.imread(test_image_path)
    result = inference_detector_by_patches(model, img, [1024], [824], [1.0], 0.1)
    # result holds one array per class; append the class index as an extra
    # column so the class survives the flattening into one file.
    dst = []
    for index, typeresult in enumerate(result):
        for lineresult in typeresult:
            dst.append(np.append(lineresult, np.float32(index)))
    # With no detections, dst is empty and an empty file is written.
    np.savetxt(dst_label_path, dst, delimiter=',')
全部代码为 https://files.cnblogs.com/files/blogs/758212/MMRotat_infer.rar?t=1682895717&download=true
之后还需要进一步修改代码,或者对保存的结果再做一次数据格式转换。
转换为 FAIR1M 数据格式后提交评测,训练 30 个 epoch 获得了这个分数。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!