YOLOV5源码解读-datasets.py

  下载好yolov5源码后,打开文件yolov5-master\utils\datasets.py,我这边下载的是最新源码。尽管yolov5的作者一直在更新代码,但都是小更。今天主要看看datasets.py中的dataloader中的内容。YOLOV4相对于V5,V4偏向于学术,V5偏向于应用,原理差不多;但是V5做得很完善,把能想到的应用场景都想到了,代码简洁、安全、可靠。

  datasets.py文件中主要有以下内容:

1、支持读取本地图片/视频和本地摄像头画面,也支持从网络摄像头(RTSP/HTTP视频流)在线读取;

2、各种数据增强(MixUp、Cutout、CutMix、Mosaic、随机擦除、打补丁、HSV空间增强等),参考我这篇博客:

https://www.cnblogs.com/winslam/p/14417867.html

3、数据的快速读取(可以从cache中读取)、数据的校验

(注:源码中的数据增强,例如翻转、马赛克拼接等操作会改变图像,对应的label也必须随之改变,这一点作者已经处理好了;如果你自己要添加增强功能,记得同步修改label,label很容易改错)

 

 

 

 

 

   可以看到,仅数据加载部分作者就写了1000多行代码,其中最重要的是下面这个类:

 1 #功能:读取、验证数据和标签
 2 # path:数据集路径
 3 # img_size:640*640
 4 # batch_size:咱们显卡8G勉强吃得消    
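 5 # hyp:超参数字典,本文件中主要用到其中的数据增强参数,eg:mosaic、mixup、degrees、translate、scale、shear、hsv_h/s/v、flipud/fliplr
 6 # rect:矩形训练开关。开启后按宽高比排序,同一batch内的图像只padding到该batch所需的最小矩形(stride的整数倍),而不是统一padding成正方形,可减少无效填充
 7 # cache_images:是否把图片预先加载到内存中以加快训练(大数据集可能超出内存);标签则总会被缓存到.cache文件中,下次直接读取
 8 # single_cls:针对单类别任务时,将label置为0
 9 # stride:模型的下采样步长,rect模式下batch的形状会对齐到它的整数倍
10 # pad:rect模式下计算batch形状时额外增加的填充量(以stride为单位)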
11 
12 class LoadImagesAndLabels(Dataset):  # for training/testing

  我把所有注释放到下面,有兴趣可以看看:

   1 # Dataset utils and dataloaders
   2 
   3 import glob
   4 import logging
   5 import math
   6 import os
   7 import random
   8 import shutil
   9 import time
  10 from itertools import repeat
  11 from multiprocessing.pool import ThreadPool
  12 from pathlib import Path
  13 from threading import Thread
  14 
  15 import cv2
  16 import numpy as np
  17 import torch
  18 import torch.nn.functional as F
  19 from PIL import Image, ExifTags
  20 from torch.utils.data import Dataset
  21 from tqdm import tqdm
  22 
  23 from utils.general import xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, resample_segments, \
  24     clean_str
  25 from utils.torch_utils import torch_distributed_zero_first
  26 
  27 # Parameters
  28 help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  29 img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp']  # acceptable image suffixes
  30 vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
  31 logger = logging.getLogger(__name__)
  32 
  33 # Get orientation exif tag
  34 for orientation in ExifTags.TAGS.keys():
  35     if ExifTags.TAGS[orientation] == 'Orientation':
  36         break
  37 
  38 # 用于验证数据集
  39 def get_hash(files):
  40     # Returns a single hash value of a list of files
  41     return sum(os.path.getsize(f) for f in files if os.path.isfile(f))
  42 
  43 
  44 def exif_size(img):
  45     # Returns exif-corrected PIL size
  46     s = img.size  # (width, height)
  47     try:
  48         rotation = dict(img._getexif().items())[orientation]
  49         if rotation == 6:  # rotation 270
  50             s = (s[1], s[0])
  51         elif rotation == 8:  # rotation 90
  52             s = (s[1], s[0])
  53     except:
  54         pass
  55 
  56     return s
  57 
  58 
  59 def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
  60                       rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
  61     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
  62     with torch_distributed_zero_first(rank):
  63         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
  64                                       augment=augment,  # augment images
  65                                       hyp=hyp,  # augmentation hyperparameters
  66                                       rect=rect,  # rectangular training
  67                                       cache_images=cache,
  68                                       single_cls=opt.single_cls,
  69                                       stride=int(stride),
  70                                       pad=pad,
  71                                       image_weights=image_weights,
  72                                       prefix=prefix)
  73 
  74     batch_size = min(batch_size, len(dataset))
  75     nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
  76     sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
  77     loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
  78     # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
  79     dataloader = loader(dataset,
  80                         batch_size=batch_size,
  81                         num_workers=nw,
  82                         sampler=sampler,
  83                         pin_memory=True,
  84                         collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
  85     return dataloader, dataset
  86 
  87 
  88 class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
  89     """ Dataloader that reuses workers
  90 
  91     Uses same syntax as vanilla DataLoader
  92     """
  93 
  94     def __init__(self, *args, **kwargs):
  95         super().__init__(*args, **kwargs)
  96         object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  97         self.iterator = super().__iter__()
  98 
  99     def __len__(self):
 100         return len(self.batch_sampler.sampler)
 101 
 102     def __iter__(self):
 103         for i in range(len(self)):
 104             yield next(self.iterator)
 105 
 106 
 107 class _RepeatSampler(object):
 108     """ Sampler that repeats forever
 109 
 110     Args:
 111         sampler (Sampler)
 112     """
 113 
 114     def __init__(self, sampler):
 115         self.sampler = sampler
 116 
 117     def __iter__(self):
 118         while True:
 119             yield from iter(self.sampler)
 120 
 121 # 加载图片
 122 class LoadImages:  # for inference
 123     def __init__(self, path, img_size=640, stride=32):
 124         p = str(Path(path).absolute())  # os-agnostic absolute path
 125         if '*' in p:
 126             files = sorted(glob.glob(p, recursive=True))  # glob
 127         elif os.path.isdir(p):
 128             files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
 129         elif os.path.isfile(p):
 130             files = [p]  # files
 131         else:
 132             raise Exception(f'ERROR: {p} does not exist')
 133 
 134         images = [x for x in files if x.split('.')[-1].lower() in img_formats]
 135         videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
 136         ni, nv = len(images), len(videos)
 137 
 138         self.img_size = img_size
 139         self.stride = stride
 140         self.files = images + videos
 141         self.nf = ni + nv  # number of files
 142         self.video_flag = [False] * ni + [True] * nv
 143         self.mode = 'image'
 144         if any(videos):
 145             self.new_video(videos[0])  # new video
 146         else:
 147             self.cap = None
 148         assert self.nf > 0, f'No images or videos found in {p}. ' \
 149                             f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'
 150 
 151     def __iter__(self):
 152         self.count = 0
 153         return self
 154 
 155     def __next__(self):
 156         if self.count == self.nf:
 157             raise StopIteration
 158         path = self.files[self.count]
 159 
 160         if self.video_flag[self.count]:
 161             # Read video
 162             self.mode = 'video'
 163             ret_val, img0 = self.cap.read()
 164             if not ret_val:
 165                 self.count += 1
 166                 self.cap.release()
 167                 if self.count == self.nf:  # last video
 168                     raise StopIteration
 169                 else:
 170                     path = self.files[self.count]
 171                     self.new_video(path)
 172                     ret_val, img0 = self.cap.read()
 173 
 174             self.frame += 1
 175             print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')
 176 
 177         else:
 178             # Read image
 179             self.count += 1
 180             img0 = cv2.imread(path)  # BGR
 181             assert img0 is not None, 'Image Not Found ' + path
 182             print(f'image {self.count}/{self.nf} {path}: ', end='')
 183 
 184         # Padded resize
 185         img = letterbox(img0, self.img_size, stride=self.stride)[0]
 186 
 187         # Convert
 188         img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
 189         img = np.ascontiguousarray(img)
 190 
 191         return path, img, img0, self.cap
 192 
 193     def new_video(self, path):
 194         self.frame = 0
 195         self.cap = cv2.VideoCapture(path)
 196         self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
 197 
 198     def __len__(self):
 199         return self.nf  # number of files
 200 
 201 # 从webCam中获取图片
 202 class LoadWebcam:  # for inference
 203     def __init__(self, pipe='0', img_size=640, stride=32):
 204         self.img_size = img_size
 205         self.stride = stride
 206 
 207         if pipe.isnumeric():
 208             pipe = eval(pipe)  # local camera
 209         # pipe = 'rtsp://192.168.1.64/1'  # IP camera
 210         # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
 211         # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera
 212 
 213         self.pipe = pipe
 214         self.cap = cv2.VideoCapture(pipe)  # video capture object
 215         self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size
 216 
 217     def __iter__(self):
 218         self.count = -1
 219         return self
 220 
 221     def __next__(self):
 222         self.count += 1
 223         if cv2.waitKey(1) == ord('q'):  # q to quit
 224             self.cap.release()
 225             cv2.destroyAllWindows()
 226             raise StopIteration
 227 
 228         # Read frame
 229         if self.pipe == 0:  # local camera
 230             ret_val, img0 = self.cap.read()
 231             img0 = cv2.flip(img0, 1)  # flip left-right
 232         else:  # IP camera
 233             n = 0
 234             while True:
 235                 n += 1
 236                 self.cap.grab()
 237                 if n % 30 == 0:  # skip frames
 238                     ret_val, img0 = self.cap.retrieve()
 239                     if ret_val:
 240                         break
 241 
 242         # Print
 243         assert ret_val, f'Camera Error {self.pipe}'
 244         img_path = 'webcam.jpg'
 245         print(f'webcam {self.count}: ', end='')
 246 
 247         # Padded resize
 248         img = letterbox(img0, self.img_size, stride=self.stride)[0]
 249 
 250         # Convert
 251         img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
 252         img = np.ascontiguousarray(img)
 253 
 254         return img_path, img, img0, None
 255 
 256     def __len__(self):
 257         return 0
 258 
 259 # 加载视频流
 260 class LoadStreams:  # multiple IP or RTSP cameras
 261     def __init__(self, sources='streams.txt', img_size=640, stride=32):
 262         self.mode = 'stream'
 263         self.img_size = img_size
 264         self.stride = stride
 265 
 266         if os.path.isfile(sources):
 267             with open(sources, 'r') as f:
 268                 sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
 269         else:
 270             sources = [sources]
 271 
 272         n = len(sources)
 273         self.imgs = [None] * n
 274         self.sources = [clean_str(x) for x in sources]  # clean source names for later
 275         for i, s in enumerate(sources):
 276             # Start the thread to read frames from the video stream
 277             print(f'{i + 1}/{n}: {s}... ', end='')
 278             cap = cv2.VideoCapture(eval(s) if s.isnumeric() else s)
 279             assert cap.isOpened(), f'Failed to open {s}'
 280             w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 281             h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 282             fps = cap.get(cv2.CAP_PROP_FPS) % 100
 283             _, self.imgs[i] = cap.read()  # guarantee first frame
 284             thread = Thread(target=self.update, args=([i, cap]), daemon=True)
 285             print(f' success ({w}x{h} at {fps:.2f} FPS).')
 286             thread.start()
 287         print('')  # newline
 288 
 289         # check for common shapes
 290         s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
 291         self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
 292         if not self.rect:
 293             print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
 294 
 295     def update(self, index, cap):
 296         # Read next stream frame in a daemon thread
 297         n = 0
 298         while cap.isOpened():
 299             n += 1
 300             # _, self.imgs[index] = cap.read()
 301             cap.grab()
 302             if n == 4:  # read every 4th frame
 303                 success, im = cap.retrieve()
 304                 self.imgs[index] = im if success else self.imgs[index] * 0
 305                 n = 0
 306             time.sleep(0.01)  # wait time
 307 
 308     def __iter__(self):
 309         self.count = -1
 310         return self
 311 
 312     def __next__(self):
 313         self.count += 1
 314         img0 = self.imgs.copy()
 315         if cv2.waitKey(1) == ord('q'):  # q to quit
 316             cv2.destroyAllWindows()
 317             raise StopIteration
 318 
 319         # Letterbox
 320         img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
 321 
 322         # Stack
 323         img = np.stack(img, 0)
 324 
 325         # Convert
 326         img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
 327         img = np.ascontiguousarray(img)
 328 
 329         return self.sources, img, img0, None
 330 
 331     def __len__(self):
 332         return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
 333 
 334 
 335 def img2label_paths(img_paths):
 336     # Define label paths as a function of image paths
 337     # 只需要将label框的信息读取进来
 338     sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
 339     return ['txt'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths]
 340 
 341 #功能:读取、验证数据和标签
 342 # path:数据集路径
 343 # img_size:640*640
 344 # batch_size:咱们显卡8G勉强吃得消    
 345 # hyp:训练过程超参数,eg:学习率、衰减、动量?
 346 # rect:训练过程中会将w!=h的图像padding成“正方形”,这个参数是该功能的控制开关(长方形似乎也可以训练吧,只不过慢一点)
 347 # cache_images:原本每次遍历标签比较慢,而缓存为cache文件后可以加快读取数据速度!
 348 # single_cls:针对单类别任务时,将label置为0
 349 # stride:降采样的系数
 350 # pad:功能比较多,eg:Mosica合成四张图的时候,用于补全边角等位置的像素
 351 
 352 class LoadImagesAndLabels(Dataset):  # for training/testing
 353     def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
 354                  cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
 355         self.img_size = img_size
 356         self.augment = augment
 357         self.hyp = hyp
 358         self.image_weights = image_weights
 359         self.rect = False if image_weights else rect
 360         # true or false
 361         self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
 362         # 马赛克拼接过程中,希望如何拼接?????
 363         self.mosaic_border = [-img_size // 2, -img_size // 2]
 364         self.stride = stride
 365         self.path = path
 366 
 367         try:
 368             f = []  # image files
 369             for p in path if isinstance(path, list) else [path]:
 370                 p = Path(p)  # os-agnostic,数据集所在路径
 371                 if p.is_dir():  # dir
 372                     # 所有格式的图片全部找到
 373                     f += glob.glob(str(p / '**' / '*.*'), recursive=True)
 374                     # f = list(p.rglob('**/*.*'))  # pathlib
 375                 elif p.is_file():  # file
 376                     with open(p, 'r') as t:
 377                         t = t.read().strip().splitlines()
 378                         parent = str(p.parent) + os.sep
 379                         f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
 380                         # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
 381                 else:
 382                     raise Exception(f'{prefix}{p} does not exist')
 383             # n = len(self.img_files) 图片数量
 384             self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
 385             # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
 386             assert self.img_files, f'{prefix}No images found'
 387         except Exception as e:
 388             raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')
 389 
 390         # Check cache 后续读取文件会优先检查缓存文件,如果有,就读。
 391         self.label_files = img2label_paths(self.img_files)  # labels
 392         cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
 393         #如果有缓存,那就直接读取就行
 394         if cache_path.is_file(): 
 395             cache, exists = torch.load(cache_path), True  # load
 396             if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
 397                 cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
 398         else:
 399             cache, exists = self.cache_labels(cache_path, prefix), False  # cache
 400 
 401         # Display cache
 402         nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
 403         if exists:
 404             d = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
 405             tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
 406         assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
 407 
 408         # Read cache
 409         cache.pop('hash')  # remove hash
 410         cache.pop('version')  # remove version
 411         # 从缓存中读取labels、框的shapes
 412         labels, shapes, self.segments = zip(*cache.values())
 413         self.labels = list(labels)
 414         self.shapes = np.array(shapes, dtype=np.float64)
 415         self.img_files = list(cache.keys())  # update
 416         self.label_files = img2label_paths(cache.keys())  # update
 417         if single_cls:
 418             for x in self.labels:
 419                 x[:, 0] = 0
 420         # 图片数量
 421         n = len(shapes)  # number of images
 422         # batch索引,eg:batch = 8 对应一个epoch
 423         bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index
 424         # 一个epoch有nb个batch
 425         nb = bi[-1] + 1  # number of batches
 426         self.batch = bi  # batch index of image
 427         self.n = n
 428         self.indices = range(n)
 429 
 430         # Rectangular Training
 431         if self.rect:
 432             # Sort by aspect ratio
 433             s = self.shapes  # wh
 434             ar = s[:, 1] / s[:, 0]  # aspect ratio
 435             irect = ar.argsort()
 436             self.img_files = [self.img_files[i] for i in irect]
 437             self.label_files = [self.label_files[i] for i in irect]
 438             self.labels = [self.labels[i] for i in irect]
 439             self.shapes = s[irect]  # wh
 440             ar = ar[irect]
 441 
 442             # Set training image shapes
 443             shapes = [[1, 1]] * nb
 444             for i in range(nb):
 445                 ari = ar[bi == i]
 446                 mini, maxi = ari.min(), ari.max()
 447                 if maxi < 1:
 448                     shapes[i] = [maxi, 1]
 449                 elif mini > 1:
 450                     shapes[i] = [1, 1 / mini]
 451 
 452             self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
 453 
 454         # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
 455         self.imgs = [None] * n
 456         if cache_images:
 457             gb = 0  # Gigabytes of cached images
 458             self.img_hw0, self.img_hw = [None] * n, [None] * n
 459             results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))  # 8 threads
 460             pbar = tqdm(enumerate(results), total=n)
 461             for i, x in pbar:
 462                 self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
 463                 gb += self.imgs[i].nbytes
 464                 pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
 465 
 466     def cache_labels(self, path=Path('./labels.cache'), prefix=''):
 467         # Cache dataset labels, check images and read shapes
 468         x = {}  # dict
 469         # 读取数据过程中显示一些错误信息,eg:上述可能读取到非图片文件,这时可以打印出错误信息
 470         # eg:丢失、找到(正常)、空的、重复
 471         nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, duplicate
 472         # 控制台的进度条!
 473         pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
 474         for i, (im_file, lb_file) in enumerate(pbar):
 475             try:
 476                 # verify images 验证图片,eg:不能太小
 477                 im = Image.open(im_file)
 478                 im.verify()  # PIL verify
 479                 shape = exif_size(im)  # image size
 480                 segments = []  # instance segments
 481                 assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
 482                 assert im.format.lower() in img_formats, f'invalid image format {im.format}'
 483 
 484                 # verify labels 验证标签
 485                 if os.path.isfile(lb_file):
 486                     nf += 1  # label found
 487                     with open(lb_file, 'r') as f:
 488                         l = [x.split() for x in f.read().strip().splitlines()]
 489                         if any([len(x) > 8 for x in l]):  # is segment
 490                             classes = np.array([x[0] for x in l], dtype=np.float32)
 491                             segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
 492                             l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
 493                         l = np.array(l, dtype=np.float32)
 494                     if len(l):
 495                         assert l.shape[1] == 5, 'labels require 5 columns each'
 496                         assert (l >= 0).all(), 'negative labels'
 497                         assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
 498                         assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
 499                     else:
 500                         ne += 1  # label empty
 501                         l = np.zeros((0, 5), dtype=np.float32)
 502                 else:
 503                     nm += 1  # label missing
 504                     l = np.zeros((0, 5), dtype=np.float32)
 505                 x[im_file] = [l, shape, segments]
 506             except Exception as e:
 507                 nc += 1
 508                 print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
 509 
 510             pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... " \
 511                         f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
 512 
 513         if nf == 0:
 514             print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
 515 
 516         x['hash'] = get_hash(self.label_files + self.img_files)
 517         x['results'] = nf, nm, ne, nc, i + 1
 518         x['version'] = 0.1  # cache version
 519         torch.save(x, path)  # save for next time
 520         logging.info(f'{prefix}New cache created: {path}')
 521         return x
 522 
 523     def __len__(self):
 524         return len(self.img_files)
 525 
 526     # def __iter__(self):
 527     #     self.count = -1
 528     #     print('ran dataset iter')
 529     #     #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
 530     #     return self
 531     # 这个函数是训练时,图片向网络中传的时候执行
 532     def __getitem__(self, index):
 533         index = self.indices[index]  # linear, shuffled, or image_weights
 534 
 535         hyp = self.hyp
 536         mosaic = self.mosaic and random.random() < hyp['mosaic']
 537         # 马赛克增强
 538         if mosaic:
 539             # Load mosaic
 540             # load_mosaic:如何将四张图合成一张图
 541             img, labels = load_mosaic(self, index)
 542             shapes = None
 543 
 544             # MixUp https://arxiv.org/pdf/1710.09412.pdf
 545             if random.random() < hyp['mixup']:
 546                 img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
 547                 r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
 548                 img = (img * r + img2 * (1 - r)).astype(np.uint8)
 549                 labels = np.concatenate((labels, labels2), 0)
 550 
 551         else:
 552             # Load image
 553             img, (h0, w0), (h, w) = load_image(self, index)
 554 
 555             # Letterbox
 556             shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
 557             img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
 558             shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling
 559 
 560             labels = self.labels[index].copy()
 561             if labels.size:  # normalized xywh to pixel xyxy format
 562                 labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
 563 
 564         if self.augment:
 565             # Augment imagespace
 566             if not mosaic:
 567                 # 这里数据增强了,标签也会改变,所以也会返回label
 568                 img, labels = random_perspective(img, labels,
 569                                                  degrees=hyp['degrees'],
 570                                                  translate=hyp['translate'],
 571                                                  scale=hyp['scale'],
 572                                                  shear=hyp['shear'],
 573                                                  perspective=hyp['perspective'])
 574 
 575             # Augment colorspace 增强色彩空间,色调、饱和度、亮度
 576             augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
 577 
 578             # Apply cutouts
 579             # if random.random() < 0.9:
 580             #     labels = cutout(img, labels)
 581 
 582         nL = len(labels)  # number of labels
 583         if nL:
 584             # 用xywh替换xyxy,用于表示框的几何信息
 585             labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
 586             labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
 587             labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1
 588 
 589         if self.augment:
 590             # flip up-down 图像反转增强
 591             if random.random() < hyp['flipud']:
 592                 img = np.flipud(img)
 593                 if nL:
 594                     # 图片反转了,标签也是
 595                     labels[:, 2] = 1 - labels[:, 2]
 596 
 597             # flip left-right
 598             if random.random() < hyp['fliplr']:
 599                 img = np.fliplr(img)
 600                 if nL:
 601                     labels[:, 1] = 1 - labels[:, 1]
 602 
 603         labels_out = torch.zeros((nL, 6))
 604         if nL:
 605             labels_out[:, 1:] = torch.from_numpy(labels)
 606 
 607         # Convert
 608         img = img[:, :, ::-1].transpose(2, 0, 1)  # opencv中BGR 转为 pytorch中的 RGB, to 3x416x416
 609         img = np.ascontiguousarray(img)
 610 
 611         return torch.from_numpy(img), labels_out, self.img_files[index], shapes
 612 
 613     @staticmethod
 614     def collate_fn(batch):
 615         img, label, path, shapes = zip(*batch)  # transposed
 616         for i, l in enumerate(label):
 617             l[:, 0] = i  # add target image index for build_targets()
 618         return torch.stack(img, 0), torch.cat(label, 0), path, shapes
 619 
 620     @staticmethod
 621     def collate_fn4(batch):
 622         img, label, path, shapes = zip(*batch)  # transposed
 623         n = len(shapes) // 4
 624         img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
 625 
 626         ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
 627         wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
 628         s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
 629         for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
 630             i *= 4
 631             if random.random() < 0.5:
 632                 im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
 633                     0].type(img[i].type())
 634                 l = label[i]
 635             else:
 636                 im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
 637                 l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
 638             img4.append(im)
 639             label4.append(l)
 640 
 641         for i, l in enumerate(label4):
 642             l[:, 0] = i  # add target image index for build_targets()
 643 
 644         return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
 645 
 646 
 647 # Ancillary functions --------------------------------------------------------------------------------------------------
 648 def load_image(self, index):
 649     # loads 1 image from dataset, returns img, original hw, resized hw
 650     img = self.imgs[index]
 651     if img is None:  # not cached
 652         path = self.img_files[index]
 653         img = cv2.imread(path)  # BGR
 654         assert img is not None, 'Image Not Found ' + path
 655         h0, w0 = img.shape[:2]  # orig hw
 656         r = self.img_size / max(h0, w0)  # resize image to img_size
 657         if r != 1:  # always resize down, only resize up if training with augmentation
 658             interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
 659             img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
 660         return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
 661     else:
 662         return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized
 663 
 664 # 色调、饱和度、亮度空间增强
 665 def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
 666     r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
 667     hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
 668     dtype = img.dtype  # uint8
 669 
 670     x = np.arange(0, 256, dtype=np.int16)
 671     lut_hue = ((x * r[0]) % 180).astype(dtype)
 672     lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
 673     lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
 674 
 675     img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
 676     cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
 677 
 678 
 679 def hist_equalize(img, clahe=True, bgr=False):
 680     # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
 681     yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
 682     if clahe:
 683         c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
 684         yuv[:, :, 0] = c.apply(yuv[:, :, 0])
 685     else:
 686         yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
 687     return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image to RGB
 688 
# Mosaic augmentation:
# four images are stitched into one canvas, and their labels are recomputed.
def load_mosaic(self, index):
    """Build a 4-image mosaic for sample ``index`` and apply random_perspective.

    The image at ``index`` plus three images chosen with ``random.choices`` over
    ``self.indices`` (repeats possible) are loaded with ``load_image``. They are
    pasted into a ``2*s x 2*s`` canvas filled with 114 (``s = self.img_size``),
    as top-left, top-right, bottom-left and bottom-right tiles meeting at a
    random point ``(xc, yc)``.

    Labels and segments are converted from normalised xywh to canvas pixel
    xyxy and clipped to ``[0, 2*s]``. Everything is then passed through
    ``random_perspective`` with the hyper-parameters in ``self.hyp`` and
    ``border=self.mosaic_border``.

    Returns:
        (img4, labels4): augmented image and label rows ``[cls, x1, y1, x2, y2]``.
    """
    # loads images in a 4-mosaic

    labels4, segments4 = [], []  # per-tile label arrays / flat list of segments
    s = self.img_size  # tile target size; the canvas is 2s x 2s
    # The four tiles meet at one random point (xc, yc). Each coordinate is drawn
    # from uniform(-x, 2 * s + x) for x in self.mosaic_border.
    # NOTE(review): mosaic_border is defined outside this chunk; if it is
    # (-s // 2, -s // 2) the centre lies in [s/2, 3s/2] - confirm.
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center y, x
    # The requested image plus 3 randomly chosen ones for the other tiles
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):  # four iterations, one per tile
        # Load image (already resized by load_image)
        img, _, (h, w) = load_image(self, index)

        # place img in img4. For each tile compute:
        #   (x1a, y1a, x2a, y2a) - destination region in the canvas
        #   (x1b, y1b, x2b, y2b) - part of the tile that fits in that region
        if i == 0:  # top left
            # Create the canvas on the first tile, filled with value 114
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            # Tile ends at the centre point; max(..., 0) keeps it inside the canvas
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            # Take the bottom-right part of the tile that fits the region
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
        # Copy the selected part of the tile into its canvas region
        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        # Offset from tile pixel coordinates to canvas pixel coordinates;
        # passed to xywhn2xyxy / xyn2xy below to place the labels on the canvas
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels: convert this tile's labels into canvas coordinates
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels: coordinates may fall outside the canvas, so clip them
    # in place to [0, 2 * s]
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # The mosaic canvas is complete; now apply a random rotation, translation,
    # scale, shear and perspective warp to it (and its labels)
    # Augment
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],  # rotation
                                       translate=self.hyp['translate'],  # translation
                                       scale=self.hyp['scale'],  # scale
                                       shear=self.hyp['shear'],  # shear
                                       perspective=self.hyp['perspective'],  # perspective
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
 766 
 767 
def load_mosaic9(self, index):
    """Build a 9-image mosaic for sample ``index`` and apply random_perspective.

    The image at ``index`` plus eight images chosen with ``random.choices`` over
    ``self.indices`` are placed on a ``3*s x 3*s`` canvas filled with 114
    (``s = self.img_size``). The first image sits at ``(s, s)``; the others are
    laid out around it using the first image's size (``h0, w0``) and the previous
    tile's size (``hp, wp``).

    A random ``2*s x 2*s`` window is then cropped. Labels and segments are
    shifted and clipped to match, and the result is passed through
    ``random_perspective`` with ``self.hyp`` and ``border=self.mosaic_border``.

    Returns:
        (img9, labels9): augmented image and label rows ``[cls, x1, y1, x2, y2]``.
    """
    # loads images in a 9-mosaic

    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        # c = (xmin, ymin, xmax, ymax) of this tile in canvas coordinates; it may
        # be negative, in which case the part left of / above the canvas is dropped.
        if i == 0:  # first tile, placed at (s, s)
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w  # size of the first tile, used to lay out the others
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right (wp = previous tile's width)
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right (hp = previous tile's height)
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]  # unclamped top-left: tile -> canvas offset for labels
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords (clamped at 0)

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image: (x1 - padx, y1 - pady) skips tile pixels that fell off the canvas
        # NOTE(review): x2/y2 are not clamped to 3*s; this relies on every tile
        # fitting on the canvas (load_image output no larger than s) - confirm.
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset: random top-left (yc, xc) of the 2s x 2s crop, each in [0, s)
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # crop offset y, x
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels: shift everything into the cropped window
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # crop offset (x, y); assumes segment points are (x, y) pairs
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9
 840 
 841 
 842 def replicate(img, labels):
 843     # Replicate labels
 844     h, w = img.shape[:2]
 845     boxes = labels[:, 1:].astype(int)
 846     x1, y1, x2, y2 = boxes.T
 847     s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
 848     for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
 849         x1b, y1b, x2b, y2b = boxes[i]
 850         bh, bw = y2b - y1b, x2b - x1b
 851         yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
 852         x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
 853         img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
 854         labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
 855 
 856     return img, labels
 857 
 858 
 859 def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
 860     # Resize and pad image while meeting stride-multiple constraints
 861     shape = img.shape[:2]  # current shape [height, width]
 862     if isinstance(new_shape, int):
 863         new_shape = (new_shape, new_shape)
 864 
 865     # Scale ratio (new / old)
 866     r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
 867     if not scaleup:  # only scale down, do not scale up (for better test mAP)
 868         r = min(r, 1.0)
 869 
 870     # Compute padding
 871     ratio = r, r  # width, height ratios
 872     new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
 873     dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
 874     if auto:  # minimum rectangle
 875         dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
 876     elif scaleFill:  # stretch
 877         dw, dh = 0.0, 0.0
 878         new_unpad = (new_shape[1], new_shape[0])
 879         ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
 880 
 881     dw /= 2  # divide padding into 2 sides
 882     dh /= 2
 883 
 884     if shape[::-1] != new_unpad:  # resize
 885         img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
 886     top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
 887     left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
 888     img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
 889     return img, ratio, (dw, dh)
 890 
# Augmentation here is done with OpenCV (warpAffine / warpPerspective)
# rather than with PyTorch / torchvision modules.
def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    """Apply a random affine/perspective warp to ``img`` and transform its labels.

    One 3x3 matrix ``M = T @ S @ R @ P @ C`` is built from these random draws:

    - C: centre the image at the origin.
    - P: perspective terms in ``[-perspective, perspective]``.
    - R: rotation in ``[-degrees, degrees]`` with isotropic scale in
      ``[1 - scale, 1 + scale]``.
    - S: x/y shear in ``[-shear, shear]`` degrees.
    - T: translation to ``(0.5 +/- translate)`` of the output width/height.

    The output is ``img.shape[0] + 2*border[0]`` high and
    ``img.shape[1] + 2*border[1]`` wide. Uncovered pixels are filled with 114.

    Args:
        img: image array to warp.
        targets: array of rows ``[cls, x1, y1, x2, y2]`` in ``img`` pixel coordinates.
        segments: per-target polygons. When any of them contains a non-zero
            value, new boxes are derived from the warped polygons instead of the
            warped box corners. NOTE(review): assumes len(segments) == len(targets).
        border: (y, x) amount added to each side of the output (negative shrinks it).

    Returns:
        (img, targets): warped image and the targets that pass ``box_candidates``,
        with updated boxes.
    """
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center: move the image centre to the origin so P/R/S act about the centre
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective: random projective terms in the bottom row
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale: random angle and random isotropic scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear: random x and y shear angles (a shear, not a crop)
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation: move the origin to a random point near the output centre
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix: fold all of the transforms above into M
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates: the warp moved the objects, so their boxes
    # must be recomputed
    n = len(targets)
    if n:
        # Use the polygons when any segment contains a non-zero value
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))  # homogeneous coordinates
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip (segment2box is defined elsewhere; presumably the polygon's
                # bounding box restricted to width x height)
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            # Warp all four corners of every box, then take their axis-aligned bounds
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates: compare each new box with its original box scaled by s
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return img, targets
 980 
 981 
 982 def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
 983     # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
 984     w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
 985     w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
 986     ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
 987     return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
 988 
 989 # 随机裁剪增强
 990 def cutout(image, labels):
 991     # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
 992     h, w = image.shape[:2]
 993 
 994     def bbox_ioa(box1, box2):
 995         # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
 996         box2 = box2.transpose()
 997 
 998         # Get the coordinates of bounding boxes
 999         b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
1000         b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
1001 
1002         # Intersection area
1003         inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
1004                      (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
1005 
1006         # box2 area
1007         box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16
1008 
1009         # Intersection over box2 area
1010         return inter_area / box2_area
1011 
1012     # create random masks
1013     scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
1014     for s in scales:
1015         mask_h = random.randint(1, int(h * s))
1016         mask_w = random.randint(1, int(w * s))
1017 
1018         # box
1019         xmin = max(0, random.randint(0, w) - mask_w // 2)
1020         ymin = max(0, random.randint(0, h) - mask_h // 2)
1021         xmax = min(w, xmin + mask_w)
1022         ymax = min(h, ymin + mask_h)
1023 
1024         # apply random color mask
1025         image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
1026 
1027         # return unobscured labels
1028         if len(labels) and s > 0.03:
1029             box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
1030             ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
1031             labels = labels[ioa < 0.60]  # remove >60% obscured labels
1032 
1033     return labels
1034 
1035 
def create_folder(path='./new'):
    """Make ``path`` an empty directory, deleting it (and its contents) first if it exists."""
    if not os.path.exists(path):
        os.makedirs(path)  # nothing to clean up
        return
    shutil.rmtree(path)  # wipe the previous output folder
    os.makedirs(path)
1041 
1042 
def flatten_recursive(path='../coco128'):
    """Copy every file under ``path`` (recursively) into one flat directory.

    The destination is ``<path>_flat``, a sibling of ``path``, recreated empty
    by ``create_folder`` first. Only names matching ``*.*`` are copied. Files
    with the same basename overwrite each other in the destination.

    Args:
        path: source directory, as ``str`` or ``pathlib.Path``.
    """
    # str(Path(...)) accepts Path objects and drops a trailing slash, so
    # '../coco128/' maps to '../coco128_flat' rather than a folder inside the source.
    root = str(Path(path))
    new_path = Path(root + '_flat')
    create_folder(new_path)
    for file in tqdm(glob.glob(root + '/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)
1049 
1050 
def extract_boxes(path='../coco128/'):  # from utils.datasets import *; extract_boxes('../coco128')
    """Convert a detection dataset into a classification dataset, one directory per class.

    Every image under ``path`` (recursive; suffix in ``img_formats``, case-insensitive)
    that has a label file (found via ``img2label_paths``) is cropped once per label
    row. Each box is enlarged by 1.2x plus 3 px and clipped to the image. The crop
    is written to ``path/classifier/<class>/<path.stem>_<image stem>_<row>.jpg``.
    An existing ``path/classifier`` directory is deleted first.

    Raises:
        AssertionError: if OpenCV fails to write a crop.
    """
    path = Path(path)  # images dir
    classifier_dir = path / 'classifier'
    if classifier_dir.is_dir():
        shutil.rmtree(classifier_dir)  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:].lower() not in img_formats:
            continue

        # image
        im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
        h, w = im.shape[:2]

        # labels
        lb_file = Path(img2label_paths([str(im_file)])[0])
        if not lb_file.exists():
            continue
        with open(lb_file, 'r') as f:
            lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

        for j, x in enumerate(lb):
            c = int(x[0])  # class
            crop_file = classifier_dir / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
            crop_file.parent.mkdir(parents=True, exist_ok=True)

            # presumably YOLO-format normalised xywh -> pixel xywh
            b = x[1:] * [w, h, w, h]  # box
            # b[2:] = b[2:].max()  # rectangle to square
            b[2:] = b[2:] * 1.2 + 3  # pad
            # builtin int: the np.int alias was removed in NumPy 1.24
            b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

            b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
            b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
            # Write outside the assert so the crop is still saved under `python -O`.
            written = cv2.imwrite(str(crop_file), im[b[1]:b[3], b[0]:b[2]])
            assert written, f'box failure in {crop_file}'
1085 
def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0)):  # from utils.datasets import *; autosplit('../coco128')
    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    # Arguments
        path:       Path to images directory
        weights:    Train, val, test weights (list)

    Every file under ``path`` is randomly assigned a split with
    ``random.choices(weights=weights)``. Only files whose suffix is in
    ``img_formats`` (case-insensitive) are listed, one path per line, in the
    matching txt file. Existing autosplit_*.txt files are removed first, and a
    split with no images gets no file.
    """
    path = Path(path)  # images dir
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split
    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    for name in txt:  # remove existing
        if (path / name).exists():
            (path / name).unlink()

    # Collect lines per split, then open each txt file once instead of once per image.
    split_lines = {0: [], 1: [], 2: []}
    for i, img in tqdm(zip(indices, files), total=n):
        if img.suffix[1:].lower() in img_formats:
            split_lines[i].append(str(img) + '\n')
    for i, lines in split_lines.items():
        if lines:
            with open(path / txt[i], 'a') as f:
                f.writelines(lines)  # add images to txt file

 

posted @ 2021-03-10 15:10  佚名12  阅读(249)  评论(0编辑  收藏  举报