YOLOv5 Source Code Walkthrough: datasets.py
Once you have downloaded the yolov5 source, open yolov5-master\utils\datasets.py. I am working from the latest code; the author updates the repo constantly, but the changes are all minor. Today we focus on the dataloader logic in datasets.py. Comparing YOLOv4 and v5: v4 leans academic while v5 leans practical, and the underlying ideas are much the same, but v5 is remarkably complete. It anticipates just about every application scenario, and the code is concise, safe, and reliable.
datasets.py mainly covers the following:
1. Reading images from local disk or a webcam, as well as online from IP/RTSP cameras;
2. A whole family of data augmentations (MixUp, Cutout, CutMix, Mosaic, random erasing, patching, HSV-space augmentation, etc.); for background see my earlier post:
https://www.cnblogs.com/winslam/p/14417867.html
3. Fast data loading (labels can be read back from a .cache file) and dataset validation.
(Note: in the source, augmentations such as flipping or mosaic stitching change the image, so the labels must change with it. The author has already handled this; if you write augmentations of your own, remember to update the labels too, which is easy to get wrong. A small sketch of the flip case follows.)
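Here is a minimal sketch of my own (not code from datasets.py, although __getitem__ does exactly this for flips): in YOLO format the box centre is normalized to [0, 1], so a left-right flip maps the x-centre to 1 - x while y, w and h stay put.

import numpy as np

# one image and its YOLO labels: each row is (class, x_center, y_center, w, h), normalized to 0-1
img = np.zeros((480, 640, 3), dtype=np.uint8)      # dummy image standing in for a real photo
labels = np.array([[0., 0.25, 0.40, 0.10, 0.20]])  # one box, left of centre

img = np.fliplr(img)               # flip the image left-right
labels[:, 1] = 1.0 - labels[:, 1]  # mirror the x-centre; y, w, h are unchanged

print(labels)  # [[0.   0.75 0.4  0.1  0.2 ]] -- the box is now right of centre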
As you can see, the author poured a solid 1000+ lines of code into data loading alone. The most important piece is this class:
# Purpose: read and validate the dataset images and labels
# path: dataset path
# img_size: e.g. 640*640
# batch_size: 16 by default, about what an 8 GB GPU can handle
# hyp: training hyperparameters; this class uses the augmentation gains (mosaic, mixup, hsv_*, flip*, ...)
# rect: rectangular training -- pad each batch to a shared w!=h rectangle instead of a full square,
#       so less padding is wasted and loading is faster than square letterboxing
# cache_images: cache decoded images in RAM to speed up training (the label scan is separately cached to a .cache file)
# single_cls: for single-class tasks, force all class labels to 0
# stride: the network's maximum downsampling factor
# pad: extra padding applied when computing the per-batch rectangular shapes

class LoadImagesAndLabels(Dataset):  # for training/testing
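Before the full listing, a rough usage sketch of this class may help. This is my own illustration: the dataset path and hyperparameter values below are placeholders, not repo defaults (the real gains live in the hyp yaml files shipped with the repo):

from utils.datasets import LoadImagesAndLabels

# hypothetical augmentation gains, normally loaded from a hyp yaml
hyp = {'mosaic': 1.0, 'mixup': 0.0, 'degrees': 0.0, 'translate': 0.1, 'scale': 0.5,
       'shear': 0.0, 'perspective': 0.0, 'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4,
       'flipud': 0.0, 'fliplr': 0.5}

dataset = LoadImagesAndLabels('../my_dataset/images/train',  # placeholder path
                              img_size=640, batch_size=16, augment=True, hyp=hyp,
                              rect=False,          # False: square letterbox; True: rectangular batches
                              cache_images=False,  # True keeps decoded images in RAM
                              single_cls=False, stride=32, pad=0.0)
img, labels_out, file, shapes = dataset[0]  # uint8 CHW tensor, (n, 6) labels, file path, shape info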
I have put the fully annotated file below for anyone interested:
# Dataset utils and dataloaders

import glob
import logging
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.general import xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, resample_segments, \
    clean_str
from utils.torch_utils import torch_distributed_zero_first

# Parameters
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp']  # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
logger = logging.getLogger(__name__)

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break

# used to check whether the dataset still matches the cache
def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except:
        pass

    return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset


class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)

108 """ Sampler that repeats forever 109 110 Args: 111 sampler (Sampler) 112 """ 113 114 def __init__(self, sampler): 115 self.sampler = sampler 116 117 def __iter__(self): 118 while True: 119 yield from iter(self.sampler) 120 121 # 加载图片 122 class LoadImages: # for inference 123 def __init__(self, path, img_size=640, stride=32): 124 p = str(Path(path).absolute()) # os-agnostic absolute path 125 if '*' in p: 126 files = sorted(glob.glob(p, recursive=True)) # glob 127 elif os.path.isdir(p): 128 files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir 129 elif os.path.isfile(p): 130 files = [p] # files 131 else: 132 raise Exception(f'ERROR: {p} does not exist') 133 134 images = [x for x in files if x.split('.')[-1].lower() in img_formats] 135 videos = [x for x in files if x.split('.')[-1].lower() in vid_formats] 136 ni, nv = len(images), len(videos) 137 138 self.img_size = img_size 139 self.stride = stride 140 self.files = images + videos 141 self.nf = ni + nv # number of files 142 self.video_flag = [False] * ni + [True] * nv 143 self.mode = 'image' 144 if any(videos): 145 self.new_video(videos[0]) # new video 146 else: 147 self.cap = None 148 assert self.nf > 0, f'No images or videos found in {p}. ' \ 149 f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}' 150 151 def __iter__(self): 152 self.count = 0 153 return self 154 155 def __next__(self): 156 if self.count == self.nf: 157 raise StopIteration 158 path = self.files[self.count] 159 160 if self.video_flag[self.count]: 161 # Read video 162 self.mode = 'video' 163 ret_val, img0 = self.cap.read() 164 if not ret_val: 165 self.count += 1 166 self.cap.release() 167 if self.count == self.nf: # last video 168 raise StopIteration 169 else: 170 path = self.files[self.count] 171 self.new_video(path) 172 ret_val, img0 = self.cap.read() 173 174 self.frame += 1 175 print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='') 176 177 else: 178 # Read image 179 self.count += 1 180 img0 = cv2.imread(path) # BGR 181 assert img0 is not None, 'Image Not Found ' + path 182 print(f'image {self.count}/{self.nf} {path}: ', end='') 183 184 # Padded resize 185 img = letterbox(img0, self.img_size, stride=self.stride)[0] 186 187 # Convert 188 img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 189 img = np.ascontiguousarray(img) 190 191 return path, img, img0, self.cap 192 193 def new_video(self, path): 194 self.frame = 0 195 self.cap = cv2.VideoCapture(path) 196 self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) 197 198 def __len__(self): 199 return self.nf # number of files 200 201 # 从webCam中获取图片 202 class LoadWebcam: # for inference 203 def __init__(self, pipe='0', img_size=640, stride=32): 204 self.img_size = img_size 205 self.stride = stride 206 207 if pipe.isnumeric(): 208 pipe = eval(pipe) # local camera 209 # pipe = 'rtsp://192.168.1.64/1' # IP camera 210 # pipe = 'rtsp://username:password@192.168.1.64/1' # IP camera with login 211 # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg' # IP golf camera 212 213 self.pipe = pipe 214 self.cap = cv2.VideoCapture(pipe) # video capture object 215 self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size 216 217 def __iter__(self): 218 self.count = -1 219 return self 220 221 def __next__(self): 222 self.count += 1 223 if cv2.waitKey(1) == ord('q'): # q to quit 224 self.cap.release() 225 cv2.destroyAllWindows() 226 raise StopIteration 227 228 # Read frame 229 if self.pipe == 0: # local camera 230 ret_val, img0 = self.cap.read() 231 img0 = 
cv2.flip(img0, 1) # flip left-right 232 else: # IP camera 233 n = 0 234 while True: 235 n += 1 236 self.cap.grab() 237 if n % 30 == 0: # skip frames 238 ret_val, img0 = self.cap.retrieve() 239 if ret_val: 240 break 241 242 # Print 243 assert ret_val, f'Camera Error {self.pipe}' 244 img_path = 'webcam.jpg' 245 print(f'webcam {self.count}: ', end='') 246 247 # Padded resize 248 img = letterbox(img0, self.img_size, stride=self.stride)[0] 249 250 # Convert 251 img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 252 img = np.ascontiguousarray(img) 253 254 return img_path, img, img0, None 255 256 def __len__(self): 257 return 0 258 259 # 加载视频流 260 class LoadStreams: # multiple IP or RTSP cameras 261 def __init__(self, sources='streams.txt', img_size=640, stride=32): 262 self.mode = 'stream' 263 self.img_size = img_size 264 self.stride = stride 265 266 if os.path.isfile(sources): 267 with open(sources, 'r') as f: 268 sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())] 269 else: 270 sources = [sources] 271 272 n = len(sources) 273 self.imgs = [None] * n 274 self.sources = [clean_str(x) for x in sources] # clean source names for later 275 for i, s in enumerate(sources): 276 # Start the thread to read frames from the video stream 277 print(f'{i + 1}/{n}: {s}... ', end='') 278 cap = cv2.VideoCapture(eval(s) if s.isnumeric() else s) 279 assert cap.isOpened(), f'Failed to open {s}' 280 w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 281 h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 282 fps = cap.get(cv2.CAP_PROP_FPS) % 100 283 _, self.imgs[i] = cap.read() # guarantee first frame 284 thread = Thread(target=self.update, args=([i, cap]), daemon=True) 285 print(f' success ({w}x{h} at {fps:.2f} FPS).') 286 thread.start() 287 print('') # newline 288 289 # check for common shapes 290 s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes 291 self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal 292 if not self.rect: 293 print('WARNING: Different stream shapes detected. 
For optimal performance supply similarly-shaped streams.') 294 295 def update(self, index, cap): 296 # Read next stream frame in a daemon thread 297 n = 0 298 while cap.isOpened(): 299 n += 1 300 # _, self.imgs[index] = cap.read() 301 cap.grab() 302 if n == 4: # read every 4th frame 303 success, im = cap.retrieve() 304 self.imgs[index] = im if success else self.imgs[index] * 0 305 n = 0 306 time.sleep(0.01) # wait time 307 308 def __iter__(self): 309 self.count = -1 310 return self 311 312 def __next__(self): 313 self.count += 1 314 img0 = self.imgs.copy() 315 if cv2.waitKey(1) == ord('q'): # q to quit 316 cv2.destroyAllWindows() 317 raise StopIteration 318 319 # Letterbox 320 img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0] 321 322 # Stack 323 img = np.stack(img, 0) 324 325 # Convert 326 img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416 327 img = np.ascontiguousarray(img) 328 329 return self.sources, img, img0, None 330 331 def __len__(self): 332 return 0 # 1E12 frames = 32 streams at 30 FPS for 30 years 333 334 335 def img2label_paths(img_paths): 336 # Define label paths as a function of image paths 337 # 只需要将label框的信息读取进来 338 sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings 339 return ['txt'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths] 340 341 #功能:读取、验证数据和标签 342 # path:数据集路径 343 # img_size:640*640 344 # batch_size:咱们显卡8G勉强吃得消 345 # hyp:训练过程超参数,eg:学习率、衰减、动量? 346 # rect:训练过程中会将w!=h的图像padding成“正方形”,这个参数是该功能的控制开关(长方形似乎也可以训练吧,只不过慢一点) 347 # cache_images:原本每次遍历标签比较慢,而缓存为cache文件后可以加快读取数据速度! 348 # single_cls:针对单类别任务时,将label置为0 349 # stride:降采样的系数 350 # pad:功能比较多,eg:Mosica合成四张图的时候,用于补全边角等位置的像素 351 352 class LoadImagesAndLabels(Dataset): # for training/testing 353 def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False, 354 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''): 355 self.img_size = img_size 356 self.augment = augment 357 self.hyp = hyp 358 self.image_weights = image_weights 359 self.rect = False if image_weights else rect 360 # true or false 361 self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training) 362 # 马赛克拼接过程中,希望如何拼接????? 
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic dataset path
                if p.is_dir():  # dir
                    # grab every image file, whatever the extension
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p, 'r') as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            # n = len(self.img_files)  # number of images
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

        # Check cache -- later runs look for a cache file first and read from it if present
        self.label_files = img2label_paths(self.img_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
        # if a valid cache exists, just load it
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True  # load
            if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
                cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

        # Read cache
        cache.pop('hash')  # remove hash
        cache.pop('version')  # remove version
        # pull the labels and image shapes out of the cache
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0
        n = len(shapes)  # number of images
        # batch index of each image, e.g. with batch_size=16 images 0-15 belong to batch 0
        bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index
        # one epoch contains nb batches
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))  # 8 threads
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
                gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
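
    # [Annotation of mine, not in the source] Before reading cache_labels below, recall the on-disk
    # layout it validates. img2label_paths maps each image to a sibling .txt by swapping /images/
    # for /labels/, e.g. (made-up paths):
    #     ../my_dataset/images/train/000001.jpg  ->  ../my_dataset/labels/train/000001.txt
    # and each row of the .txt is one object: class index, then a normalized xywh box:
    #     0 0.481719 0.634028 0.690625 0.713278
    #     45 0.339609 0.418060 0.416406 0.262500
    # The asserts below enforce exactly this: 5 columns per row, nothing negative, coordinates
    # no larger than 1, and no duplicate rows.
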
    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        # report problems found while scanning, e.g. non-image files
        # counters: missing / found (OK) / empty / corrupt
        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, duplicate
        # console progress bar
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        for i, (im_file, lb_file) in enumerate(pbar):
            try:
                # verify images, e.g. reject anything under 10 pixels
                im = Image.open(im_file)
                im.verify()  # PIL verify
                shape = exif_size(im)  # image size
                segments = []  # instance segments
                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
                assert im.format.lower() in img_formats, f'invalid image format {im.format}'

                # verify labels
                if os.path.isfile(lb_file):
                    nf += 1  # label found
                    with open(lb_file, 'r') as f:
                        l = [x.split() for x in f.read().strip().splitlines()]
                        if any([len(x) > 8 for x in l]):  # is segment
                            classes = np.array([x[0] for x in l], dtype=np.float32)
                            segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
                            l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                        l = np.array(l, dtype=np.float32)
                    if len(l):
                        assert l.shape[1] == 5, 'labels require 5 columns each'
                        assert (l >= 0).all(), 'negative labels'
                        assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                        assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
                    else:
                        ne += 1  # label empty
                        l = np.zeros((0, 5), dtype=np.float32)
                else:
                    nm += 1  # label missing
                    l = np.zeros((0, 5), dtype=np.float32)
                x[im_file] = [l, shape, segments]
            except Exception as e:
                nc += 1
                print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... " \
                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

        if nf == 0:
            print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, i + 1
        x['version'] = 0.1  # cache version
        torch.save(x, path)  # save for next time
        logging.info(f'{prefix}New cache created: {path}')
        return x

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    # __getitem__ runs every time an image is fetched for the network during training
    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        # mosaic augmentation
        if mosaic:
            # Load mosaic
            # load_mosaic stitches four images into one
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
                r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

        if self.augment:
            # Augment imagespace
            if not mosaic:
                # the geometry changes here, so the labels are transformed and returned as well
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

            # Augment colorspace: hue, saturation, value
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            # switch the box representation from xyxy back to xywh
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1

        if self.augment:
            # flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nL:
                    # the image was flipped, so flip the labels too
                    labels[:, 2] = 1 - labels[:, 2]

            # flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # OpenCV BGR to PyTorch RGB, HWC to CHW (e.g. 3x416x416)
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
                    0].type(img[i].type())
                l = label[i]
            else:
                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
                l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            img4.append(im)
            label4.append(l)

        for i, l in enumerate(label4):
            l[:, 0] = i  # add target image index for build_targets()

        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4


# Ancillary functions --------------------------------------------------------------------------------------------------
def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized

# HSV (hue, saturation, value) colour augmentation
def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed


def hist_equalize(img, clahe=True, bgr=False):
    # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
    yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
    if clahe:
        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image to RGB

# Mosaic augmentation:
# four images are stitched into one, and the labels are recomputed accordingly
def load_mosaic(self, index):
    # loads images in a 4-mosaic

    labels4, segments4 = [], []  # lists
    s = self.img_size  # 640
    # The four tiles meet at a single random point p(xc, yc).
    # p is confined to the central region of the 2s x 2s canvas: with mosaic_border = [-320, -320],
    # each coordinate is drawn from uniform(-x, 2*s + x) with x = -320, i.e. uniform(320, 960).
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    # besides the image at `index`, randomly pick three more images to stitch
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):  # loop over the four tiles
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        # 1. initialize the big canvas;
        # 2. work out where this tile sits on the canvas;
        # 3. work out which part of the small image gets copied there;
        if i == 0:  # top left
            # initialize the canvas: 2s x 2s, filled with grey value 114
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            # canvas: (x1a, y1a) is the tile's top-left vertex, (x2a, y2a) its bottom-right;
            # max() stops the tile from running off the canvas
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            # source image
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
        # copy the ROI of the small image into its slot on the canvas
        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        # offsets that shift this tile's labels into canvas coordinates
        padw = x1a - x1b
        padh = y1a - y1b

        # recompute the labels, since every tile has moved into canvas coordinates
        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels: coordinates may now fall outside the canvas, so clip them back inside
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # The steps above produced one big stitched image;
    # now apply random rotation, translation, scale and shear to it
    # Augment
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],  # rotation
                                       translate=self.hyp['translate'],  # translation
                                       scale=self.hyp['scale'],  # scale
                                       shear=self.hyp['shear'],  # shear
                                       perspective=self.hyp['perspective'],  # perspective
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
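
# [Annotation of mine, not in the source] A quick sanity check of the mosaic geometry above,
# assuming the default img_size = 640, i.e. s = 640 and mosaic_border = [-320, -320]:
def _mosaic_center_demo(s=640):
    # mirrors the yc, xc line in load_mosaic: uniform(-x, 2*s + x) with x = -320 is uniform(320, 960),
    # so the mosaic centre always lands in the middle region of the 2s x 2s (1280x1280) canvas
    border = [-s // 2, -s // 2]
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in border]
    assert s // 2 <= xc <= 2 * s - s // 2 and s // 2 <= yc <= 2 * s - s // 2
    # load_mosaic then calls random_perspective(..., border=mosaic_border), which computes
    # height = 2*s + 2*(-s//2) = s and so crops the canvas back down to the training size
    return yc, xc
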
def load_mosaic9(self, index):
    # loads images in a 9-mosaic

    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9


def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return img, labels


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)

# Geometric augmentation is done here with OpenCV rather than torchvision modules
def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective: build a random perspective matrix
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale: build a random rotation/scale matrix
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear: build a random shear matrix
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation: build a random translation matrix
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    # fold all of the transforms above into a single matrix M
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates: the augmentation changed the image, so compute the new labels below
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return img, targets


def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates

# Cutout augmentation (random occlusion patches)
def cutout(image, labels):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels


def create_folder(path='./new'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder


def flatten_recursive(path='../coco128'):
    # Flatten a recursive directory by bringing all files to top level
    new_path = Path(path + '_flat')
    create_folder(new_path)
    for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)


def extract_boxes(path='../coco128/'):  # from utils.datasets import *; extract_boxes('../coco128')
    # Convert detection dataset into classification dataset, with one directory per class

    path = Path(path)  # images dir
    shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in img_formats:
            # image
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file, 'r') as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'

def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0)):  # from utils.datasets import *; autosplit('../coco128')
    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    # Arguments
        path:       Path to images directory
        weights:    Train, val, test weights (list)
    """
    path = Path(path)  # images dir
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split
    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    [(path / x).unlink() for x in txt if (path / x).exists()]  # remove existing
    for i, img in tqdm(zip(indices, files), total=n):
        if img.suffix[1:] in img_formats:
            with open(path / txt[i], 'a') as f:
                f.write(str(img) + '\n')  # add image to txt file
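Finally, to place the module in context: train.py consumes all of this through create_dataloader. Here is a rough sketch of that call and the resulting loop; the opt namespace, path and hyp values are stand-ins of my own for the real argparse options and hyp yaml:

from types import SimpleNamespace
from utils.datasets import create_dataloader

hyp = {'mosaic': 1.0, 'mixup': 0.0, 'degrees': 0.0, 'translate': 0.1, 'scale': 0.5,
       'shear': 0.0, 'perspective': 0.0, 'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4,
       'flipud': 0.0, 'fliplr': 0.5}  # placeholder gains
opt = SimpleNamespace(single_cls=False)  # stand-in for train.py's argparse namespace

dataloader, dataset = create_dataloader('../my_dataset/images/train', 640, 16, 32, opt,
                                        hyp=hyp, augment=True, rect=False, workers=8)

for imgs, targets, paths, shapes in dataloader:
    imgs = imgs.float() / 255.0  # uint8 0-255 -> float32 0.0-1.0, as train.py does
    # imgs: (batch, 3, H, W); targets: (n, 6) rows of (batch_index, class, x, y, w, h)
    break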