对小样本进行数据增强

针对 YOLOv3 训练数据不足的问题,采用数据增强的方式扩充样本,并同步变换原始标注中的边框坐标。

  1 import xml.etree.ElementTree as ET
  2 import os
  3 import numpy as np
  4 from PIL import Image
  5 import shutil
  6 
  7 import imgaug as ia
  8 from imgaug import augmenters as iaa
  9 
 10 
# Seed imgaug's global random number generator so the augmentations are
# repeatable from one run of the script to the next.
ia.seed(1)
 12 
 13 def read_xml_annotation(root, image_id):
 14     in_file = open(os.path.join(root, image_id))
 15     tree = ET.parse(in_file)
 16     root = tree.getroot()
 17     bndboxlist = []
 18 
 19     for object in root.findall('object'):  # 找到root节点下的所有country节点
 20         bndbox = object.find('bndbox')  # 子节点下节点rank的值
 21 
 22         xmin = int(bndbox.find('xmin').text)
 23         xmax = int(bndbox.find('xmax').text)
 24         ymin = int(bndbox.find('ymin').text)
 25         ymax = int(bndbox.find('ymax').text)
 26         # print(xmin,ymin,xmax,ymax)
 27         bndboxlist.append([xmin, ymin, xmax, ymax])
 28         # print(bndboxlist)
 29 
 30     bndbox = root.find('object').find('bndbox')
 31     return bndboxlist
 32 
 33 
 34 # (506.0000, 330.0000, 528.0000, 348.0000) -> (520.4747, 381.5080, 540.5596, 398.6603)
 35 def change_xml_annotation(root, image_id, new_target):
 36     new_xmin = new_target[0]
 37     new_ymin = new_target[1]
 38     new_xmax = new_target[2]
 39     new_ymax = new_target[3]
 40 
 41     in_file = open(os.path.join(root, str(image_id) + '.xml'))  # 这里root分别由两个意思
 42     tree = ET.parse(in_file)
 43     xmlroot = tree.getroot()
 44     object = xmlroot.find('object')
 45     bndbox = object.find('bndbox')
 46     xmin = bndbox.find('xmin')
 47     xmin.text = str(new_xmin)
 48     ymin = bndbox.find('ymin')
 49     ymin.text = str(new_ymin)
 50     xmax = bndbox.find('xmax')
 51     xmax.text = str(new_xmax)
 52     ymax = bndbox.find('ymax')
 53     ymax.text = str(new_ymax)
 54     tree.write(os.path.join(root, str("%06d" % (str(id) + '.xml'))))
 55 
 56 
 57 def change_xml_list_annotation(root, image_id, new_target, saveroot, id):
 58     in_file = open(os.path.join(root, str(image_id) + '.xml'))  # 这里root分别由两个意思
 59     tree = ET.parse(in_file)
 60     elem = tree.find('filename')
 61     elem.text = (str("%06d" % int(id)) + '.bmp')
 62     xmlroot = tree.getroot()
 63     index = 0
 64 
 65     for object in xmlroot.findall('object'):  # 找到root节点下的所有country节点
 66         bndbox = object.find('bndbox')  # 子节点下节点rank的值
 67 
 68         # xmin = int(bndbox.find('xmin').text)
 69         # xmax = int(bndbox.find('xmax').text)
 70         # ymin = int(bndbox.find('ymin').text)
 71         # ymax = int(bndbox.find('ymax').text)
 72 
 73         new_xmin = new_target[index][0]
 74         new_ymin = new_target[index][1]
 75         new_xmax = new_target[index][2]
 76         new_ymax = new_target[index][3]
 77 
 78         xmin = bndbox.find('xmin')
 79         xmin.text = str(new_xmin)
 80         ymin = bndbox.find('ymin')
 81         ymin.text = str(new_ymin)
 82         xmax = bndbox.find('xmax')
 83         xmax.text = str(new_xmax)
 84         ymax = bndbox.find('ymax')
 85         ymax.text = str(new_ymax)
 86 
 87         index = index + 1
 88 
 89     tree.write(os.path.join(saveroot, str("%06d" % int(id)) + '.xml'))
 90 
 91 
 92 def mkdir(path):
 93     # 去除首位空格
 94     path = path.strip()
 95     # 去除尾部 \ 符号
 96     path = path.rstrip("/")
 97     # 判断路径是否存在
 98     # 存在     True
 99     # 不存在   False
100     isExists = os.path.exists(path)
101     # 判断结果
102     if not isExists:
103         # 如果不存在则创建目录
104         # 创建目录操作函数
105         os.makedirs(path)
106         print(path + ' 创建成功')
107         return True
108     else:
109         # 如果目录存在则不创建,并提示目录已存在
110         print(path + ' 目录已存在')
111         return False
112 
113 
if __name__ == "__main__":

    IMG_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/images/"
    XML_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/Anotations/"

    # Output folders for the augmented annotations / images.  Both are
    # wiped and recreated on every run.
    AUG_XML_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/Anotations_Augment/"
    AUG_IMG_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/images_Augment/"
    for out_dir in (AUG_XML_DIR, AUG_IMG_DIR):
        try:
            shutil.rmtree(out_dir)
        except FileNotFoundError:
            pass  # first run: nothing to clear
        mkdir(out_dir)

    AUGLOOP = 10  # number of augmented copies generated per source image
    EPOCH_STRIDE = 250  # id offset between epochs in the output numbering
    IMG_EXT = '.bmp'

    # Augmentation pipeline.  The geometric steps (flips, Affine) also move
    # the bounding boxes; Multiply and GaussianBlur only change pixel values.
    seq = iaa.Sequential([
        iaa.Flipud(0.5),  # vertical flip for 50% of the images
        iaa.Fliplr(0.5),  # horizontal (mirror) flip for 50% of the images
        iaa.Multiply((1.2, 1.5)),  # brighten by a factor in [1.2, 1.5]
        iaa.GaussianBlur(sigma=(0, 2.0)),  # blur with sigma in [0, 2]
        iaa.Affine(
            translate_px={"x": 15, "y": 15},  # shift 15 px on x and y
            scale=(0.8, 0.95),  # shrink to 80-95%
            rotate=(-30, 30),  # rotate by -30..30 degrees
        ),
    ])

    # Only top-level .xml files of XML_DIR are processed; sorting makes the
    # id assignment reproducible between runs.
    xml_names = sorted(n for n in os.listdir(XML_DIR) if n.lower().endswith('.xml'))
    num_files = len(xml_names)

    aug_count = 0  # running counter of generated samples, used for the ids
    for name in xml_names:
        stem = os.path.splitext(name)[0]
        img_path = os.path.join(IMG_DIR, stem + IMG_EXT)

        boxes = read_xml_annotation(XML_DIR, name)
        # Keep the original sample next to the augmented ones.
        shutil.copy(os.path.join(XML_DIR, name), AUG_XML_DIR)
        shutil.copy(img_path, AUG_IMG_DIR)

        # Load the source image once; every epoch augments the same pixels.
        with Image.open(img_path) as pil_img:
            img = np.asarray(pil_img)
        height, width = img.shape[0], img.shape[1]

        for epoch in range(AUGLOOP):
            # Freeze the random parameters so the image and its boxes receive
            # exactly the same transformation.
            seq_det = seq.to_deterministic()

            new_bndbox_list = []  # [[x1, y1, x2, y2], ...], one per object
            for x1, y1, x2, y2 in boxes:
                bbs = ia.BoundingBoxesOnImage([
                    ia.BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2),
                ], shape=img.shape)
                box_aug = seq_det.augment_bounding_boxes([bbs])[0].bounding_boxes[0]

                # Clamp the transformed corners into [1, width] x [1, height].
                n_x1 = int(max(1, min(width, box_aug.x1)))
                n_y1 = int(max(1, min(height, box_aug.y1)))
                n_x2 = int(max(1, min(width, box_aug.x2)))
                n_y2 = int(max(1, min(height, box_aug.y2)))
                # A box pushed entirely past the left/top edge collapses to
                # zero width/height; give it one pixel.
                if n_x1 == 1 and n_x1 == n_x2:
                    n_x2 += 1
                if n_y1 == 1 and n_y2 == n_y1:
                    n_y2 += 1
                if n_x1 >= n_x2 or n_y1 >= n_y2:
                    print('error', name)
                new_bndbox_list.append([n_x1, n_y1, n_x2, n_y2])

            # Save the augmented image as-is (no boxes are drawn on it).
            image_aug = seq_det.augment_images([img])[0]
            new_id = num_files + aug_count + epoch * EPOCH_STRIDE
            Image.fromarray(image_aug).save(
                os.path.join(AUG_IMG_DIR, "%06d" % new_id + IMG_EXT))

            # Save the matching annotation under the same id.
            change_xml_list_annotation(XML_DIR, stem, new_bndbox_list, AUG_XML_DIR, new_id)
            print("%06d" % new_id + IMG_EXT)

            aug_count += 1
View Code

  参考:https://blog.csdn.net/weixin_45829462/article/details/105951949

posted @ 2022-03-14 13:29  赵家小伙儿  阅读(249)  评论(0)  编辑  收藏  举报