对小样本进行数据增强
针对 YOLOv3 训练数据不足的情况,可以采用数据增强的方式扩充样本;对图像做几何变换的同时,需要同步更新原始标注框(bounding box)的坐标。

import xml.etree.ElementTree as ET
import os
import numpy as np
from PIL import Image
import shutil

import imgaug as ia
from imgaug import augmenters as iaa


ia.seed(1)

# Child tags of <bndbox>, in the same order as the [xmin, ymin, xmax, ymax]
# coordinate lists used throughout this script.
_BNDBOX_TAGS = ('xmin', 'ymin', 'xmax', 'ymax')


def read_xml_annotation(root, image_id):
    """Read every bounding box from an annotation XML file.

    The file is expected to contain ``<object>`` elements, each with a
    ``<bndbox>`` child holding ``<xmin>``, ``<ymin>``, ``<xmax>``, ``<ymax>``.

    Args:
        root: Directory that contains the annotation file.
        image_id: Annotation file name, including its extension.

    Returns:
        A list of ``[xmin, ymin, xmax, ymax]`` integer lists, one per
        ``<object>`` element in document order. The list is empty when the
        file contains no ``<object>`` element.
    """
    # Passing the path lets ElementTree open the file itself (binary mode,
    # honouring the XML encoding declaration) and close it afterwards.
    tree = ET.parse(os.path.join(root, image_id))
    bndboxlist = []

    for obj in tree.getroot().findall('object'):
        bndbox = obj.find('bndbox')
        xmin = int(float(bndbox.find('xmin').text))
        xmax = int(float(bndbox.find('xmax').text))
        ymin = int(float(bndbox.find('ymin').text))
        ymax = int(float(bndbox.find('ymax').text))
        bndboxlist.append([xmin, ymin, xmax, ymax])

    return bndboxlist


# (506.0000, 330.0000, 528.0000, 348.0000) -> (520.4747, 381.5080, 540.5596, 398.6603)
def change_xml_annotation(root, image_id, new_target):
    """Overwrite the first object's bounding box in ``<image_id>.xml``.

    Args:
        root: Directory that contains the annotation file.
        image_id: Annotation file name without the ``.xml`` extension.
        new_target: Sequence ``(xmin, ymin, xmax, ymax)`` with the new
            coordinates; each value is stored via ``str()``.

    Raises:
        AttributeError: If the file has no ``<object>``/``<bndbox>`` element.
        IndexError: If ``new_target`` has fewer than four entries.
    """
    xml_path = os.path.join(root, str(image_id) + '.xml')
    tree = ET.parse(xml_path)
    bndbox = tree.getroot().find('object').find('bndbox')

    for position, tag in enumerate(_BNDBOX_TAGS):
        bndbox.find(tag).text = str(new_target[position])

    # The original output path was built from the builtin ``id`` with a
    # "%06d" format applied to a string, which always raised TypeError.
    # Saving back to the file that was read is the only target derivable
    # from this function's arguments.
    tree.write(xml_path)


def change_xml_list_annotation(root, image_id, new_target, saveroot, id):
    """Write a copy of an annotation with all bounding boxes replaced.

    Args:
        root: Directory that contains the source annotation.
        image_id: Source annotation file name without the ``.xml`` extension.
        new_target: One ``[xmin, ymin, xmax, ymax]`` entry per ``<object>``
            element, in document order.
        saveroot: Directory that receives the new annotation file.
        id: Integer-like id of the new sample. The output is named
            ``"%06d.xml"`` of this id and ``<filename>`` is set to the
            matching ``.bmp`` name.

    Raises:
        IndexError: If ``new_target`` has fewer entries than the file has
            ``<object>`` elements.
    """
    tree = ET.parse(os.path.join(root, str(image_id) + '.xml'))
    new_stem = str("%06d" % int(id))

    filename_elem = tree.find('filename')
    if filename_elem is not None:  # tolerate annotations without <filename>
        filename_elem.text = new_stem + '.bmp'

    for index, obj in enumerate(tree.getroot().findall('object')):
        bndbox = obj.find('bndbox')
        for position, tag in enumerate(_BNDBOX_TAGS):
            bndbox.find(tag).text = str(new_target[index][position])

    tree.write(os.path.join(saveroot, new_stem + '.xml'))


def mkdir(path):
    """Create ``path`` (with parents) unless it already exists.

    Args:
        path: Directory path; surrounding whitespace and trailing ``/``
            characters are stripped before use.

    Returns:
        True when the directory was created, False when it already existed.
    """
    # Strip leading/trailing whitespace, then trailing '/' separators.
    path = path.strip()
    path = path.rstrip("/")

    if not os.path.exists(path):
        os.makedirs(path)
        print(path + ' 创建成功')
        return True

    # Nothing to create; report that the directory is already there.
    print(path + ' 目录已存在')
    return False


def _clip_box(box, img_shape):
    """Clamp an augmented box to ``[1, width] x [1, height]`` as integers.

    ``box`` must expose ``x1``, ``y1``, ``x2``, ``y2``; ``img_shape`` is the
    image array shape, read as ``(height, width, ...)``.
    """
    height, width = img_shape[0], img_shape[1]
    n_x1 = int(max(1, min(width, box.x1)))
    n_y1 = int(max(1, min(height, box.y1)))
    n_x2 = int(max(1, min(width, box.x2)))
    n_y2 = int(max(1, min(height, box.y2)))
    # A box squashed onto the lower bound is widened to one pixel.
    if n_x1 == 1 and n_x1 == n_x2:
        n_x2 += 1
    if n_y1 == 1 and n_y2 == n_y1:
        n_y2 += 1
    return [n_x1, n_y1, n_x2, n_y2]


if __name__ == "__main__":

    IMG_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/images/"
    XML_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/Anotations/"

    # Output folder for the augmented XML annotations; recreated on each run.
    AUG_XML_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/Anotations_Augment/"
    try:
        shutil.rmtree(AUG_XML_DIR)
    except FileNotFoundError:
        pass  # first run: nothing to delete
    mkdir(AUG_XML_DIR)

    # Output folder for the augmented images; recreated on each run.
    AUG_IMG_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/images_Augment/"
    try:
        shutil.rmtree(AUG_IMG_DIR)
    except FileNotFoundError:
        pass  # first run: nothing to delete
    mkdir(AUG_IMG_DIR)

    AUGLOOP = 10  # number of augmented samples generated per source image
    MIN_ID_STRIDE = 250  # id offset between augmentation rounds (original scheme)

    # Augmentation pipeline.
    seq = iaa.Sequential([
        iaa.Flipud(0.5),  # vertical flip with probability 0.5
        iaa.Fliplr(0.5),  # horizontal mirror with probability 0.5
        iaa.Multiply((1.2, 1.5)),  # change brightness, doesn't affect BBs
        iaa.GaussianBlur(sigma=(0, 2.0)),  # blur with sigma sampled from [0, 2.0]
        iaa.Affine(
            translate_px={"x": 15, "y": 15},
            scale=(0.8, 0.95),
            rotate=(-30, 30)
        )  # shift 15px on x/y, scale to 80-95%, rotate within +/-30; affects BBs
    ])

    for _root, _sub_folders, files in os.walk(XML_DIR):
        # Only annotation files are processed; anything else would make the
        # XML parser fail.
        xml_names = [name for name in files if name.lower().endswith('.xml')]

        # Ids are file_count + file_index + epoch * stride. The stride must
        # be at least the file count, otherwise ids from different epochs
        # overlap; for up to 250 files this matches the original numbering.
        # NOTE(review): ids start at the file count, so they can still clash
        # with source files that already use six-digit numeric names in that
        # range - confirm against the dataset's naming.
        id_stride = max(MIN_ID_STRIDE, len(xml_names))

        for name_index, name in enumerate(xml_names):
            stem = os.path.splitext(name)[0]
            img_path = os.path.join(IMG_DIR, stem + '.bmp')

            bndbox = read_xml_annotation(XML_DIR, name)
            # Keep the untouched original next to its augmented variants.
            shutil.copy(os.path.join(XML_DIR, name), AUG_XML_DIR)
            shutil.copy(img_path, AUG_IMG_DIR)

            for epoch in range(AUGLOOP):
                # Freeze the sampled parameters so that the boxes and the
                # image receive the same transformation.
                seq_det = seq.to_deterministic()

                with Image.open(img_path) as pil_img:
                    img = np.asarray(pil_img)

                # Transform each box, then clamp it to the image area.
                # new_bndbox_list: [[x1, y1, x2, y2], ...]
                new_bndbox_list = []
                for box in bndbox:
                    bbs = ia.BoundingBoxesOnImage([
                        ia.BoundingBox(x1=box[0], y1=box[1], x2=box[2], y2=box[3]),
                    ], shape=img.shape)
                    bbs_aug = seq_det.augment_bounding_boxes([bbs])[0]

                    clipped = _clip_box(bbs_aug.bounding_boxes[0], img.shape)
                    if clipped[0] >= clipped[2] or clipped[1] >= clipped[3]:
                        # Degenerate box after clamping: reported, still
                        # written so the XML keeps one box per object.
                        print('error', name)
                    new_bndbox_list.append(clipped)

                new_id = len(xml_names) + name_index + epoch * id_stride
                new_stem = str("%06d" % new_id)

                # Save the augmented image. The original routed it through
                # bbs.draw_on_image(..., thickness=0) first; that depended on
                # `bbs` leaking out of the loop above (undefined when the
                # annotation has no boxes) and drew nothing visible.
                image_aug = seq_det.augment_images([img])[0]
                path = os.path.join(AUG_IMG_DIR, new_stem + '.bmp')
                Image.fromarray(image_aug).save(path)

                # Save the matching annotation.
                change_xml_list_annotation(XML_DIR, stem, new_bndbox_list,
                                           AUG_XML_DIR, new_id)
                print(new_stem + '.bmp')
参考:https://blog.csdn.net/weixin_45829462/article/details/105951949
分类:
目标检测
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律