YOLOv7 Source Code Walkthrough: Data Loading
1. YOLOv7 Code Organization
.
├── cfg (network architectures defined in `yaml`)
│   ├── baseline (baselines used for comparison)
│   │   ├── r50-csp.yaml
│   │   ├── x50-csp.yaml
│   │   ├── yolor-csp-x.yaml
│   │   ├── yolor-csp.yaml
│   │   ├── yolor-d6.yaml
│   │   ├── yolor-e6.yaml
│   │   ├── yolor-p6.yaml
│   │   ├── yolor-w6.yaml
│   │   ├── yolov3-spp.yaml
│   │   ├── yolov3.yaml
│   │   └── yolov4-csp.yaml
│   ├── deploy (used for deployment)
│   │   ├── yolov7-d6.yaml
│   │   ├── yolov7-e6e.yaml
│   │   ├── yolov7-e6.yaml
│   │   ├── yolov7-tiny-silu.yaml
│   │   ├── yolov7-tiny.yaml
│   │   ├── yolov7-w6.yaml
│   │   ├── yolov7x.yaml
│   │   └── yolov7.yaml
│   └── training (used for training)
│       ├── yolov7-d6.yaml
│       ├── yolov7-e6e.yaml
│       ├── yolov7-e6.yaml
│       ├── yolov7-tiny.yaml
│       ├── yolov7-w6.yaml
│       ├── yolov7x.yaml
│       └── yolov7.yaml
├── data
│   ├── coco.yaml (COCO dataset information)
│   ├── hyp.scratch.custom.yaml (these four files hold the training hyperparameters)
│   ├── hyp.scratch.p5.yaml
│   ├── hyp.scratch.p6.yaml
│   └── hyp.scratch.tiny.yaml
├── deploy (deployment related)
│   └── triton-inference-server
│       ├── boundingbox.py
│       ├── client.py
│       ├── data
│       │   ├── dog.jpg
│       │   └── dog_result.jpg
│       ├── labels.py
│       ├── processing.py
│       ├── README.md
│       └── render.py
├── detect.py (detection script that can be run directly)
├── export.py (exports to TorchScript, CoreML, TorchScript-Lite and ONNX)
├── figure (images)
│   ├── horses_prediction.jpg
│   ├── mask.png
│   ├── performance.png
│   ├── pose.png
│   ├── tennis_caption.png
│   ├── tennis.jpg
│   ├── tennis_panoptic.png
│   └── tennis_semantic.jpg
├── hubconf.py (does not seem to be needed for now)
├── inference
│   └── images
│       ├── bus.jpg
│       ├── horses.jpg
│       ├── image1.jpg
│       ├── image2.jpg
│       ├── image3.jpg
│       └── zidane.jpg
├── LICENSE.md
├── models (**key**, network definitions)
│   ├── common.py (building blocks used by the networks)
│   ├── experimental.py (experimental building blocks)
│   ├── __init__.py
│   └── yolo.py (network definition, including yaml parsing)
├── README.md
├── requirements.txt
├── runs (output written while the model runs)
│   └── train
│       ├── yolov7
│       │   ├── events.out.tfevents.1661764925.ai-ai-dev-az3-01.13526.0
│       │   ├── hyp.yaml
│       │   ├── opt.yaml
│       │   └── weights
├── scripts (downloads the COCO dataset)
│   └── get_coco.sh
├── test.py (evaluates model metrics)
├── tools (Jupyter notebooks)
│   ├── compare_YOLOv7e6_vs_YOLOv5x6_half.ipynb
│   ├── compare_YOLOv7e6_vs_YOLOv5x6.ipynb
│   ├── compare_YOLOv7_vs_YOLOv5m6_half.ipynb
│   ├── compare_YOLOv7_vs_YOLOv5m6.ipynb
│   ├── compare_YOLOv7_vs_YOLOv5s6.ipynb
│   ├── instance.ipynb
│   ├── keypoint.ipynb
│   ├── reparameterization.ipynb
│   ├── visualization.ipynb
│   ├── YOLOv7CoreML.ipynb
│   ├── YOLOv7-Dynamic-Batch-ONNXRUNTIME.ipynb
│   ├── YOLOv7-Dynamic-Batch-TENSORRT.ipynb
│   ├── YOLOv7onnx.ipynb
│   └── YOLOv7trt.ipynb
├── train_aux.py (train p6 models)
├── train.py (train p5 models)
├── utils
│   ├── activations.py (defines many activation functions)
│   ├── add_nms.py
│   ├── autoanchor.py
│   ├── aws
│   │   ├── __init__.py
│   │   ├── mime.sh
│   │   ├── resume.py
│   │   └── userdata.sh
│   ├── datasets.py (**key**, data reading and loading)
│   ├── general.py (general-purpose helpers)
│   ├── google_app_engine
│   │   ├── additional_requirements.txt
│   │   ├── app.yaml
│   │   └── Dockerfile
│   ├── google_utils.py
│   ├── __init__.py
│   ├── loss.py (**defines the loss**)
│   ├── metrics.py (evaluation metrics)
│   ├── plots.py (plotting)
│   ├── torch_utils.py (YOLOR PyTorch utils)
│   └── wandb_logging
│       ├── __init__.py
│       ├── log_dataset.py
│       └── wandb_utils.py
└── yolov7.pt (pretrained model)
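2. Data Loading
The goal of this post is to debug data loading on the COCO2017 dataset. I previously wrote up the code for training YOLOv5 on the VOC dataset: https://blog.csdn.net/hymn1993/article/details/123664708 .
This post walks through the code again; the logic is essentially the same.
2.1 Code Walkthrough
Entry point: train.py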
# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                        hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                        world_size=opt.world_size, workers=opt.workers,
                                        image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
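Note: here augment is True and rect is False:
parser.add_argument('--rect', action='store_true', help='rectangular training')
We do not pass this flag when training, so rect is False. rect controls whether rectangular training/testing is enabled; by default it is off for the training set and on for the validation set, where it speeds things up. When self.rect=True, self.batch_shapes records the shape of each batch (all images in the same batch share one shape).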
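To make the effect of rect concrete, here is a minimal pure-Python sketch of how one shape per batch can be derived from image aspect ratios. It follows the same rule as the Rectangular Training branch of LoadImagesAndLabels.__init__, but it is not the repository's code: the function name rect_batch_shapes and the sample ratios are made up for illustration, and each aspect ratio is assumed to be height / width.

```python
import math


def rect_batch_shapes(aspect_ratios, batch_size, img_size=640, stride=32, pad=0.0):
    """Return one [height, width] per batch (hypothetical helper).

    aspect_ratios: height / width of every image in the dataset.
    """
    ar = sorted(aspect_ratios)  # images with similar shapes end up in the same batch
    shapes = []
    for start in range(0, len(ar), batch_size):
        chunk = ar[start:start + batch_size]
        mini, maxi = min(chunk), max(chunk)
        shape = [1.0, 1.0]  # mixed batch: keep a square
        if maxi < 1:  # every image is wider than tall: shrink the height
            shape = [maxi, 1.0]
        elif mini > 1:  # every image is taller than wide: shrink the width
            shape = [1.0, 1.0 / mini]
        # round each side up to a multiple of the model stride
        shapes.append([math.ceil(s * img_size / stride + pad) * stride for s in shape])
    return shapes


print(rect_batch_shapes([0.5, 0.75, 1.0, 1.2, 1.5, 2.0], batch_size=2))
```

Batches made only of wide or only of tall images get a smaller rectangle than 640x640, which is where the speed-up comes from.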
Definition of utils/datasets.py::create_dataloader:
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    """Called from train.py to build the dataloader and dataset (also used for the test loader).

    Custom dataloader builder:
    LoadImagesAndLabels loads the dataset (including data augmentation)
    + the distributed sampler DistributedSampler
    + the custom InfiniteDataLoader, which keeps sampling data indefinitely

    :param path: path the images are loaded from (train/test), e.g. '../COCO2017/train2017.txt'
    :param imgsz: train/test image size (size after augmentation), e.g. 640
    :param batch_size: batch size, e.g. 32
    :param stride: maximum stride of the model, e.g. 32
    :param opt.single_cls: whether the dataset has a single class, default False
    :param hyp: dict of training hyperparameters (learning rate and so on); here mainly the data augmentation
                coefficients (rotation, translation, ...) are used. Set with `--hyp` on the command line
    :param augment: whether to apply data augmentation; True for training
    :param cache: whether to cache images, default False
    :param pad: padding applied when computing the shapes for rectangular training, default 0.0
    :param rect: whether to enable rectangular train/test; off for the training set, on for the validation set
    :param rank: process index for multi-GPU training. -1 with one GPU means no distributed training, -1 with
                 several GPUs means DataParallel mode; default -1. The (global) rank of the current process.
    :param world_size: The total number of processes. Should be equal to the total number of devices (GPU)
                       used for distributed training.
    :param workers: num_workers of the dataloader, i.e. the number of CPU processes used to load data
    :param image_weights: whether to pick training images with weights derived from the distribution of their
                          ground-truth boxes, default False
    :param quad: whether the dataloader uses collate_fn4 instead of collate_fn, default False
    :param prefix: display tag, usually train/val; also used when the label cache file is saved
    """
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size, ...)  # keyword arguments elided in this listing

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
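The nw expression packs three limits into one line. The sketch below isolates that arithmetic so it can be tried without PyTorch; pick_num_workers and its cpu_count parameter are hypothetical names added here for illustration, they do not exist in the repository.

```python
import os


def pick_num_workers(batch_size, workers, world_size=1, cpu_count=None):
    """Mirror the `nw = min([...])` line of create_dataloader (hypothetical helper)."""
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
    return min(
        cpu_count // world_size,            # share of the CPU cores available to each process
        batch_size if batch_size > 1 else 0,  # batch size 1 loads in the main process
        workers,                            # upper bound requested by the caller
    )


print(pick_num_workers(batch_size=32, workers=8, world_size=1, cpu_count=16))
print(pick_num_workers(batch_size=32, workers=8, world_size=4, cpu_count=16))
print(pick_num_workers(batch_size=1, workers=8, world_size=1, cpu_count=16))
```

The smallest of the three limits wins, so on a machine shared by several DDP processes the per-process worker count drops even if a larger workers value is requested.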
Next, look at the LoadImagesAndLabels call inside utils/datasets.py::create_dataloader:
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)
Below is the LoadImagesAndLabels class, which defines the dataset:
class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
        # self.albumentations = Albumentations() if augment else None

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic, e.g. PosixPath('../COCO2017/train2017.txt')
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file (this is the branch taken here)
                    with open(p, 'r') as t:
                        # t: ['./images/train2017/000000109622.jpg', './images/train2017/000000160694.jpg', ...]
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep  # parent directory, e.g. '../COCO2017/'
                        # local to global path: every element of t becomes a full image path
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            # keep only files whose extension (compared in lower case) is one of the 9 img_formats, then sort
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # label files, one-to-one with img_files above
        # cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels, e.g. PosixPath('../COCO2017/train2017.cache')
        cache_path = Path('/data/hyz/datasets/COCO2017').with_suffix('.cache')  # where the cache file is saved
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True  # load
            # if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
            #     cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

        # Read cache
        cache.pop('hash')  # remove hash
        cache.pop('version')  # remove version
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index (upstream uses np.int, removed in NumPy 1.24)
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]  # list of image paths
            self.label_files = [self.label_files[i] for i in irect]  # list of label paths
            self.labels = [self.labels[i] for i in irect]  # list of the matching label arrays
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride  # upstream: np.int

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            if cache_images == 'disk':
                self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
                self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
                self.im_cache_dir.mkdir(parents=True, exist_ok=True)
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                if cache_images == 'disk':
                    if not self.img_npy[i].exists():
                        np.save(self.img_npy[i].as_posix(), x[0])
                    gb += self.img_npy[i].stat().st_size
                else:
                    self.imgs[i], self.img_hw0[i], self.img_hw[i] = x
                    gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):  # the method to rewrite for custom data
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        # number missing (images without a label file), found (label files found),
        # empty (label file exists but contains nothing), corrupted (samples that failed while being read)
        nm, nf, ne, nc = 0, 0, 0, 0
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        # progress bar, e.g. Scanning images: 0%| | 0/118287 [00:00<?, ?it/s]
        for i, (im_file, lb_file) in enumerate(pbar):  # loop over every sample (image jpg / label txt pair)
            try:
                # verify images
                im = Image.open(im_file)  # check that the image can be opened
                im.verify()  # PIL verify, checks file integrity
                shape = exif_size(im)  # image size
                segments = []  # instance segments
                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
                assert im.format.lower() in img_formats, f'invalid image format {im.format}'

                # verify labels
                if os.path.isfile(lb_file):
                    nf += 1  # label found
                    with open(lb_file, 'r') as f:
                        # read every line of the label txt (one annotation per line) into a list
                        l = [x.split() for x in f.read().strip().splitlines()]
                        if any([len(x) > 8 for x in l]):  # is segment: more than 8 values means a segmentation label
                            # the first column is the class, stored as a numeric string such as '45';
                            # this builds the class list of the file, e.g. [45.0, 45.0, 50.0, 45.0, 49.0, 49.0, 49.0, 49.0]
                            classes = np.array([x[0] for x in l], dtype=np.float32)
                            # after the first column, every two numbers are one point; reshape each polygon to (-1, 2)
                            segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]
                            l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh), e.g. (8, 5)
                        l = np.array(l, dtype=np.float32)
                    if len(l):
                        assert l.shape[1] == 5, 'labels require 5 columns each'  # cls, xywh
                        assert (l >= 0).all(), 'negative labels'  # every value must be >= 0
                        assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'  # boxes must stay inside the image
                        assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'  # no repeated boxes
                    else:
                        ne += 1  # label empty
                        l = np.zeros((0, 5), dtype=np.float32)
                else:
                    nm += 1  # label missing
                    l = np.zeros((0, 5), dtype=np.float32)
                # x is a dict: key = image path, value = labels (e.g. shape (8, 5)), image width/height, segment points
                x[im_file] = [l, shape, segments]
            except Exception as e:
                nc += 1
                print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"  # update the progress bar
        pbar.close()

        if nf == 0:
            print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, i + 1  # the counters
        x['version'] = 0.1  # cache version
        torch.save(x, path)  # save for next time
        logging.info(f'{prefix}New cache created: {path}')
        return x
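The four asserts in cache_labels define what a valid label file looks like. As a rough illustration, the sketch below applies the same four checks to plain text lines without NumPy. The helper names parse_label_lines and check_labels are hypothetical, and the segmentation case (more than 8 values per line) is deliberately left out.

```python
def parse_label_lines(lines):
    """Turn the text lines of one label file into tuples of floats (hypothetical helper)."""
    return [tuple(float(v) for v in line.split()) for line in lines if line.strip()]


def check_labels(rows):
    """Return a list of problems; an empty list means the rows pass all four checks."""
    problems = []
    if any(len(r) != 5 for r in rows):
        problems.append('labels require 5 columns each')  # expected: cls, x, y, w, h
        return problems  # the remaining checks assume 5 columns
    if any(v < 0 for r in rows for v in r):
        problems.append('negative labels')
    if any(v > 1 for r in rows for v in r[1:]):
        problems.append('non-normalized or out of bounds coordinate labels')
    if len(set(rows)) != len(rows):
        problems.append('duplicate labels')
    return problems


rows = parse_label_lines(['45 0.5 0.5 0.2 0.3', '49 0.1 0.2 0.05 0.05'])
print(check_labels(rows))             # no problems for well-formed rows
print(check_labels(rows + rows[:1]))  # the first row appears twice
```

Unlike the asserts in cache_labels, which stop at the first failure and count the sample as corrupted, this version collects every problem, which can be handy when cleaning a custom dataset before training.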
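If we want to train YOLOv7 on data stored in our own format, this is the part that has to change.
So how should it be changed? Below are the modifications I made for the InterHand dataset.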
cache_path = Path('/datasets/' + Path(self.path).stem).with_suffix('.cache')
if cache_path.is_file():
    cache, exists = torch.load(cache_path), True  # load
else:
    cache, exists = self.interhand_cache_labels(cache_path, prefix), False  # cache
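One detail of the cache_path expression is worth checking by hand: with_suffix replaces everything after the last dot of the file name, so a stem that itself contains a dot loses its tail. The sketch below reproduces the expression with pathlib only. The function name cache_path_for and both annotation file names are assumptions made for illustration; they are not taken from the original setup.

```python
from pathlib import PurePosixPath


def cache_path_for(annotation_path, cache_root='/datasets/'):
    """Rebuild the cache_path expression for a given annotation file (hypothetical helper)."""
    stem = PurePosixPath(annotation_path).stem  # file name without its last suffix
    return PurePosixPath(cache_root + stem).with_suffix('.cache')


# A stem without dots behaves as expected.
print(cache_path_for('/data/annotations/train_data.json'))
# A stem containing a dot: '.6M_train_data' is treated as the suffix and replaced.
print(cache_path_for('/data/annotations/InterHand2.6M_train_data.json'))
```

If train and test annotation files share such a dotted prefix, they would map to the same cache file, so appending '.cache' to the stem as a plain string may be the safer choice in that situation.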
Below is the modified loading function; it caches the labels to the specified directory.
def interhand_cache_labels(self, path=Path('./labels.cache'), prefix=''):  # rewritten version of cache_labels
    # Cache dataset labels, check images and read shapes
    db = COCO(self.path)
    hand_cls = {'right': 0, 'left': 1, 'interacting': 2}
    in_img_path = "/data/InterHand2.6M/images/InterHand2.6M_5fps_batch1/images"
    x = {}
    segments = []  # instance segments
    nm, nf, ne, nc = 0, 0, 0, 0
    pbar = tqdm(db.dataset.items(), desc='Scanning images', total=len(db.dataset.items()))
    for i, (key, value) in enumerate(pbar):
        try:
            im_file = os.path.join(in_img_path, key)
            # verify images
            im = Image.open(im_file)  # check that the image can be opened
            im.verify()  # PIL verify, checks file integrity
            shape = exif_size(im)  # image size
            assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
            assert im.format.lower() in img_formats, f'invalid image format {im.format}'
            if value['bbox'] != []:
                nf += 1  # label found
                classes = np.array(hand_cls[value['hand_type']], dtype=np.float32).reshape(-1, 1)
                bbox = np.array(value['bbox'], dtype=np.float32).reshape(-1, 4)  # top-left x, y, then w, h

                # Debug: draw the raw box on the image
                # original_img = cv2.imread(im_file)
                # bbx = value['bbox']
                # x1 = int(bbx[0])
                # y1 = int(bbx[1])
                # x2 = int(bbx[0] + bbx[2])
                # y2 = int(bbx[1] + bbx[3])
                # temp_image = cv2.rectangle(original_img, (x1, y1), (x2, y2), (0, 0, 255), 2)  # (top-left, bottom-right)
                # cv2.imwrite('./test.jpg', cv2.cvtColor(temp_image, cv2.COLOR_BGR2RGB))  # cv2 save

                l = np.concatenate((classes, xywh2cxcywh(bbox, shape)), 1)
                l = np.array(l, dtype=np.float32)

                # Debug: draw the converted box on the image
                # original_img = cv2.imread(im_file)
                # bbx = xywh2cxcywh(bbox, shape)
                # bbx = xywh2xyxy(bbx)
                # x1 = int(bbx[:, 0]*shape[0])
                # y1 = int(bbx[:, 1]*shape[1])
                # x2 = int(bbx[:, 2]*shape[0])
                # y2 = int(bbx[:, 3]*shape[1])
                # temp_image = cv2.rectangle(original_img, (x1, y1), (x2, y2), (0, 0, 255), 2)  # (top-left, bottom-right)
                # cv2.imwrite('./test.jpg', cv2.cvtColor(temp_image, cv2.COLOR_BGR2RGB))  # cv2 save

                if len(l):
                    assert l.shape[1] == 5, 'labels require 5 columns each'  # cls, xywh
                    assert (l >= 0).all(), 'negative labels'  # every value must be >= 0
                    assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'  # boxes must stay inside the image
                    assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'  # no repeated boxes
                else:
                    ne += 1  # label empty
                    l = np.zeros((0, 5), dtype=np.float32)
            else:
                nm += 1  # label missing
                l = np.zeros((0, 5), dtype=np.float32)
            x[im_file] = [l, shape, segments]
        except Exception as e:
            nc += 1
            print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

        pbar.desc = f"{prefix}Scanning '{in_img_path}' images and labels... " \
                    f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"  # update the progress bar
    pbar.close()

    if nf == 0:
        print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

    # x['hash'] = get_hash(imgs_path_list)
    x['results'] = nf, nm, ne, nc, i + 1  # the counters
    x['version'] = 0.1  # cache version
    torch.save(x, path)  # save for next time
    logging.info(f'{prefix}New cache created: {path}')
    return x
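The function above calls xywh2cxcywh, whose body is not shown in this post. Judging from how it is used (pixel boxes given as top-left x, y, width, height go in; normalized center-based boxes that pass the `<= 1` assert come out), a minimal pure-Python sketch for a single box could look like the following. This is a guess at the behaviour, not the actual helper: the name xywh_to_normalized_cxcywh is made up, and the real version evidently operates on NumPy arrays of shape (n, 4).

```python
def xywh_to_normalized_cxcywh(box, image_size):
    """Convert one box from pixel (x_left, y_top, w, h) to normalized (cx, cy, w, h).

    image_size is assumed to be (width, height), matching what exif_size returns.
    """
    x, y, w, h = box
    img_w, img_h = image_size
    cx = (x + w / 2) / img_w  # box center, as a fraction of the image width
    cy = (y + h / 2) / img_h  # box center, as a fraction of the image height
    return (cx, cy, w / img_w, h / img_h)


# A 200x100 box whose top-left corner sits at (100, 50) in a 400x200 image.
print(xywh_to_normalized_cxcywh((100, 50, 200, 100), (400, 200)))
```

Boxes that extend past the image border would produce values above 1 here and trip the out-of-bounds assert, so clipping the box to the image before conversion may be needed for some annotations.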