YOLOv7 Source Code Walkthrough: Data Loading

 


1. YOLOv7 Code Organization


YOLOv7 code structure:
.
├── cfg (network architectures defined in `yaml` format)
│   ├── baseline (baselines used for comparison)
│   │   ├── r50-csp.yaml
│   │   ├── x50-csp.yaml
│   │   ├── yolor-csp-x.yaml
│   │   ├── yolor-csp.yaml
│   │   ├── yolor-d6.yaml
│   │   ├── yolor-e6.yaml
│   │   ├── yolor-p6.yaml
│   │   ├── yolor-w6.yaml
│   │   ├── yolov3-spp.yaml
│   │   ├── yolov3.yaml
│   │   └── yolov4-csp.yaml
│   ├── deploy (used for deployment)
│   │   ├── yolov7-d6.yaml
│   │   ├── yolov7-e6e.yaml
│   │   ├── yolov7-e6.yaml
│   │   ├── yolov7-tiny-silu.yaml
│   │   ├── yolov7-tiny.yaml
│   │   ├── yolov7-w6.yaml
│   │   ├── yolov7x.yaml
│   │   └── yolov7.yaml
│   └── training (used for training)
│       ├── yolov7-d6.yaml
│       ├── yolov7-e6e.yaml
│       ├── yolov7-e6.yaml
│       ├── yolov7-tiny.yaml
│       ├── yolov7-w6.yaml
│       ├── yolov7x.yaml
│       └── yolov7.yaml
├── data
│   ├── coco.yaml (COCO dataset info)
│   ├── hyp.scratch.custom.yaml (these four files hold the training hyperparameters)
│   ├── hyp.scratch.p5.yaml
│   ├── hyp.scratch.p6.yaml
│   └── hyp.scratch.tiny.yaml
├── deploy (deployment)
│   └── triton-inference-server
│       ├── boundingbox.py
│       ├── client.py
│       ├── data
│       │   ├── dog.jpg
│       │   └── dog_result.jpg
│       ├── labels.py
│       ├── processing.py
│       ├── README.md
│       └── render.py
├── detect.py (ready-to-run detection script)
├── export.py (exports to TorchScript, CoreML, TorchScript-Lite and ONNX)
├── figure (images)
│   ├── horses_prediction.jpg
│   ├── mask.png
│   ├── performance.png
│   ├── pose.png
│   ├── tennis_caption.png
│   ├── tennis.jpg
│   ├── tennis_panoptic.png
│   └── tennis_semantic.jpg
├── hubconf.py (PyTorch Hub entry point; not needed here for now)
├── inference
│   └── images
│       ├── bus.jpg
│       ├── horses.jpg
│       ├── image1.jpg
│       ├── image2.jpg
│       ├── image3.jpg
│       └── zidane.jpg
├── LICENSE.md
├── models (**key**: network architectures)
│   ├── common.py (building blocks used in the networks)
│   ├── experimental.py (experimental modules)
│   ├── __init__.py
│   └── yolo.py (network definition, including yaml parsing)
├── README.md
├── requirements.txt
├── runs (outputs written while running)
│   └── train
│       └── yolov7
│           ├── events.out.tfevents.1661764925.ai-ai-dev-az3-01.13526.0
│           ├── hyp.yaml
│           ├── opt.yaml
│           └── weights
├── scripts (download the COCO dataset)
│   └── get_coco.sh
├── test.py (evaluate model metrics)
├── tools (Jupyter notebooks)
│   ├── compare_YOLOv7e6_vs_YOLOv5x6_half.ipynb
│   ├── compare_YOLOv7e6_vs_YOLOv5x6.ipynb
│   ├── compare_YOLOv7_vs_YOLOv5m6_half.ipynb
│   ├── compare_YOLOv7_vs_YOLOv5m6.ipynb
│   ├── compare_YOLOv7_vs_YOLOv5s6.ipynb
│   ├── instance.ipynb
│   ├── keypoint.ipynb
│   ├── reparameterization.ipynb
│   ├── visualization.ipynb
│   ├── YOLOv7CoreML.ipynb
│   ├── YOLOv7-Dynamic-Batch-ONNXRUNTIME.ipynb
│   ├── YOLOv7-Dynamic-Batch-TENSORRT.ipynb
│   ├── YOLOv7onnx.ipynb
│   └── YOLOv7trt.ipynb
├── train_aux.py (train P6 models)
├── train.py (train P5 models)
├── utils
│   ├── activations.py (many activation functions)
│   ├── add_nms.py
│   ├── autoanchor.py
│   ├── aws
│   │   ├── __init__.py
│   │   ├── mime.sh
│   │   ├── resume.py
│   │   └── userdata.sh
│   ├── datasets.py (**key**: data reading and loading)
│   ├── general.py (general-purpose helpers)
│   ├── google_app_engine
│   │   ├── additional_requirements.txt
│   │   ├── app.yaml
│   │   └── Dockerfile
│   ├── google_utils.py
│   ├── __init__.py
│   ├── loss.py (**loss definitions**)
│   ├── metrics.py (evaluation metrics)
│   ├── plots.py (plotting)
│   ├── torch_utils.py (YOLOR PyTorch utils)
│   └── wandb_logging
│       ├── __init__.py
│       ├── log_dataset.py
│       └── wandb_utils.py
└── yolov7.pt (pretrained weights)

2. Data Loading

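The goal of this post is to debug data loading on the COCO2017 dataset. I previously wrote a code walkthrough for training YOLOv5 on the VOC dataset: https://blog.csdn.net/hymn1993/article/details/123664708
This post goes through the code again for YOLOv7, but there is little difference.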

2.1 Code Walkthrough

Entry point: train.py

# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                        hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                        world_size=opt.world_size, workers=opt.workers,
                                        image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))

Note: augment = True, and rect is False:

parser.add_argument('--rect', action='store_true', help='rectangular training')

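We don't pass this flag during training, so rect is False. rect controls rectangular training/testing. By default it is off for the training set and on for the validation set, where it speeds things up. When self.rect=True, self.batch_shapes records the shape of each batch (all images in a batch share the same shape).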
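To make batch_shapes concrete, here is a plain-Python sketch of the computation LoadImagesAndLabels.__init__ performs when rect=True. Images are sorted by aspect ratio (h / w) and split into consecutive batches. Each batch then gets one letterbox shape, rounded up to a multiple of stride. The function name rect_batch_shapes is hypothetical, and the sketch mirrors the NumPy code rather than replacing it.

```python
import math


def rect_batch_shapes(shapes_wh, batch_size, img_size=640, stride=32, pad=0.0):
    """Return (order, batch_shapes) for rectangular training.

    shapes_wh: list of (width, height), one per image.
    order: image indices sorted by aspect ratio h / w.
    batch_shapes: one [h, w] letterbox size per batch, a multiple of stride.
    """
    ar = [h / w for w, h in shapes_wh]                   # aspect ratio h / w
    order = sorted(range(len(ar)), key=lambda i: ar[i])  # like ar.argsort()
    ar = [ar[i] for i in order]
    nb = math.ceil(len(ar) / batch_size)                 # number of batches
    batch_shapes = []
    for b in range(nb):
        ari = ar[b * batch_size:(b + 1) * batch_size]    # same images as ar[bi == b]
        mini, maxi = min(ari), max(ari)
        shape = [1, 1]                                   # [h, w] as fractions of img_size
        if maxi < 1:                                     # every image is wider than tall
            shape = [maxi, 1]
        elif mini > 1:                                   # every image is taller than wide
            shape = [1, 1 / mini]
        batch_shapes.append([math.ceil(s * img_size / stride + pad) * stride for s in shape])
    return order, batch_shapes
```

With two 640x480 (w x h) images in one batch, the batch is letterboxed to [480, 640] (h, w) instead of [640, 640]. The smaller canvas means less padding to compute on, which is where rectangular inference saves time. A batch that mixes wide and tall images falls back to the full square size.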

Definition of utils/datasets.py::create_dataloader:

def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    """Called from train.py to build the dataloader/dataset (and the testloader).
    Custom dataloader: builds the dataset with LoadImagesAndLabels (including data augmentation),
    adds a DistributedSampler for DDP, and wraps everything in InfiniteDataLoader, which keeps sampling indefinitely.
    :param path: path of the train/test image list, e.g. '../COCO2017/train2017.txt'
    :param imgsz: train/test image size (after augmentation), e.g. 640
    :param batch_size: batch size, e.g. 32
    :param stride: maximum model stride, e.g. 32
    :param opt.single_cls: whether the dataset is single-class, default False
    :param hyp: hyperparameter dict used during training (learning rate, etc.); here mainly the data augmentation
                coefficients (rotation, translation, ...). Set on the command line with `--hyp`
    :param augment: whether to apply data augmentation; True for training
    :param cache: whether to cache images, default False
    :param pad: padding used when computing the rectangular-training shapes, default 0.0
    :param rect: whether to use rectangular train/test; off for the training set and on for the validation set by default
    :param rank: process rank for multi-GPU training. -1 means no DDP (plain single-GPU training, or DataParallel
                 with several GPUs). Default -1. The (global) rank of the current process.
    :param world_size: The total number of processes. Should be equal to the total number of devices (GPU) used for distributed training.
    :param workers: num_workers of the DataLoader, i.e. the number of CPU worker processes that load data
    :param image_weights: whether to sample training images with weights derived from the class distribution of their ground-truth boxes, default False
    :param quad: whether the dataloader uses collate_fn4 instead of collate_fn, default False
    :param prefix: tag shown in log messages, usually train/val; also used when the label cache file is saved
    """
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
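The worker count nw in create_dataloader is the minimum of three limits: CPU cores per process, the batch size (no worker processes when batch_size is 1), and --workers. Below is a small sketch of the same formula with cpu_count injectable so it can be tested. The wrapper name dataloader_workers and the `or 1` fallback for os.cpu_count() returning None are additions that are not in the repo.

```python
import os


def dataloader_workers(batch_size, workers=8, world_size=1, cpu_count=None):
    """Mirror of the `nw = min([...])` line in create_dataloader."""
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
    return min([cpu_count // world_size,               # CPU cores available per DDP process
                batch_size if batch_size > 1 else 0,  # no worker processes for batch_size == 1
                workers])                              # user cap (--workers)
```

For example, on a 32-core machine with 4 DDP processes, batch_size=16 and the default workers=8 give min(8, 16, 8) = 8 workers per process. With only 8 cores, the same settings give 2.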

Next, focus on how the dataset is built inside utils/datasets.py::create_dataloader:

def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)
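The first thing LoadImagesAndLabels does with path is turn a list file such as train2017.txt into image paths. It then derives one label path per image with img2label_paths. Below is a minimal sketch of both steps, assuming the usual YOLO layout where labels sit next to images under a labels/ directory with a .txt extension. The function read_image_list and the IMG_FORMATS list are illustrative and are not imported from the repo. This img2label_paths is a re-implementation of the convention, not the repo's code. It also strips the leading './' instead of calling str.replace, which gives the same result for these inputs.

```python
import os
from pathlib import Path

# Assumed to match img_formats in utils/datasets.py
IMG_FORMATS = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']


def read_image_list(txt_file):
    """Resolve a list file such as train2017.txt into sorted image paths."""
    p = Path(txt_file)
    parent = str(p.parent) + os.sep                 # e.g. '../COCO2017/'
    lines = p.read_text().strip().splitlines()
    # './images/...' entries are relative to the list file's directory
    files = [parent + x[2:] if x.startswith('./') else x for x in lines]
    # keep only files with a known image extension (case-insensitive), then sort
    return sorted(x for x in files if x.split('.')[-1].lower() in IMG_FORMATS)


def img2label_paths(img_paths):
    """Map .../images/xxx.jpg to .../labels/xxx.txt (YOLO directory convention)."""
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep
    return [os.path.splitext(x.replace(sa, sb, 1))[0] + '.txt' for x in img_paths]
```

Only the first '/images/' in each path is replaced. A dataset root that itself contains a directory named images therefore still maps to the right labels folder, provided that folder comes later in the path.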

Below is the code of the LoadImagesAndLabels class, which defines the dataset:

class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
        #self.albumentations = Albumentations() if augment else None

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic, e.g. PosixPath('../COCO2017/train2017.txt')
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file (this branch runs for COCO2017)
                    with open(p, 'r') as t:
                        # t: ['./images/train2017/000000109622.jpg', './images/train2017/000000160694.jpg', ...]
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep  # parent directory, e.g. '../COCO2017/'
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path: turn each relative image path in t into a full path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])  # keep files whose lowercased extension is one of the 9 img_formats, then sort
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # label file paths, one-to-one with img_files
        # cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels, e.g. PosixPath('../COCO2017/train2017.cache')
        cache_path = Path('/data/hyz/datasets/COCO2017').with_suffix('.cache')  # author's hard-coded cache file
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True  # load
            #if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
            #    cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

        # Read cache
        cache.pop('hash')  # remove hash
        cache.pop('version')  # remove version
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index (np.int was removed in NumPy 1.24; int is equivalent)
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]  # list of image paths
            self.label_files = [self.label_files[i] for i in irect]  # list of label paths
            self.labels = [self.labels[i] for i in irect]  # matching list of label arrays
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            if cache_images == 'disk':
                self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
                self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
                self.im_cache_dir.mkdir(parents=True, exist_ok=True)
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                if cache_images == 'disk':
                    if not self.img_npy[i].exists():
                        np.save(self.img_npy[i].as_posix(), x[0])
                    gb += self.img_npy[i].stat().st_size
                else:
                    self.imgs[i], self.img_hw0[i], self.img_hw[i] = x
                    gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):  # the part to rewrite for custom formats
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc = 0, 0, 0, 0  # number missing (images without a label file), found (label files found), empty (label file exists but is empty), corrupted (samples that failed to load)
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))  # progress bar, e.g. Scanning images: 0%| | 0/118287 [00:00<?, ?it/s]
        for i, (im_file, lb_file) in enumerate(pbar):  # loop over every (image .jpg, label .txt) pair
            try:
                # verify images
                im = Image.open(im_file)  # check that the image can be opened
                im.verify()  # PIL verify: check file integrity
                shape = exif_size(im)  # image size (width, height)
                segments = []  # instance segments
                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
                assert im.format.lower() in img_formats, f'invalid image format {im.format}'

                # verify labels
                if os.path.isfile(lb_file):
                    nf += 1  # label found
                    with open(lb_file, 'r') as f:
                        l = [x.split() for x in f.read().strip().splitlines()]  # one list per line of the label file (one object per line)
                        if any([len(x) > 8 for x in l]):  # is segment: more than 8 values means polygon (segmentation) labels
                            classes = np.array([x[0] for x in l], dtype=np.float32)  # first column is the class id (a numeric string such as '45'), e.g. [45.0, 45.0, 50.0, 45.0, 49.0, 49.0, 49.0, 49.0]
                            segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # remaining values are (x, y) polygon vertices; reshape each instance to (-1, 2)
                            l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh), e.g. shape (8, 5)
                        l = np.array(l, dtype=np.float32)
                    if len(l):
                        assert l.shape[1] == 5, 'labels require 5 columns each'  # cls, xywh
                        assert (l >= 0).all(), 'negative labels'  # all values >= 0
                        assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'  # normalized coordinates must not exceed 1 (box inside the image)
                        assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'  # no duplicate boxes
                    else:
                        ne += 1  # label empty
                        l = np.zeros((0, 5), dtype=np.float32)
                else:
                    nm += 1  # label missing
                    l = np.zeros((0, 5), dtype=np.float32)
                x[im_file] = [l, shape, segments]  # key: image path; value: [labels (e.g. shape (8, 5)), image (width, height), segment coordinates]
            except Exception as e:
                nc += 1
                print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"  # update the progress bar
        pbar.close()

        if nf == 0:
            print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, i + 1  # counts
        x['version'] = 0.1  # cache version
        torch.save(x, path)  # save for next time
        logging.info(f'{prefix}New cache created: {path}')
        return x
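The label-checking part of cache_labels can be tried in isolation. The sketch below uses plain Python instead of NumPy, and the function name parse_yolo_label is hypothetical. It parses the contents of one label file. Rows with more than 8 values are treated as polygons and reduced to (cls, cx, cy, w, h) via the bounding rectangle of their points, which matches how segments2boxes derives a box. The same four assertions are then applied.

```python
def parse_yolo_label(text):
    """Parse the contents of one YOLO label .txt into [cls, cx, cy, w, h] rows."""
    rows = [line.split() for line in text.strip().splitlines()]
    if any(len(r) > 8 for r in rows):  # polygon labels: cls x1 y1 x2 y2 ...
        labels = []
        for r in rows:
            xs = [float(v) for v in r[1::2]]
            ys = [float(v) for v in r[2::2]]
            x1, y1, x2, y2 = min(xs), min(ys), max(xs), max(ys)  # bounding rectangle
            labels.append([float(r[0]), (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1])
    else:  # box labels: cls cx cy w h
        labels = [[float(v) for v in r] for r in rows]
    # The same checks cache_labels() asserts on (an empty list means "empty label")
    for row in labels:
        assert len(row) == 5, 'labels require 5 columns each'
        assert all(v >= 0 for v in row), 'negative labels'
        assert all(v <= 1 for v in row[1:]), 'non-normalized or out of bounds coordinate labels'
    assert len(set(map(tuple, labels))) == len(labels), 'duplicate labels'
    return labels
```

As in cache_labels, a file containing at least one polygon row converts every row as a polygon. Mixing plain box rows and polygon rows in one file is therefore not supported.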

To train YOLOv7 on data in your own format, this is the part you need to modify.

So how do you modify it? Below are the changes I made for the InterHand dataset.

cache_path = Path('/datasets/' + Path(self.path).stem).with_suffix('.cache')
if cache_path.is_file():
    cache, exists = torch.load(cache_path), True  # load
else:
    cache, exists = self.interhand_cache_labels(cache_path, prefix), False  # cache
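This cache-path line has a pitfall worth checking: Path.with_suffix replaces whatever follows the last dot in the file name. That is harmless for names like train2017. A stem that itself contains a dot, however, gets truncated, and the dataset is called InterHand2.6M. As a result, different annotation files could map to the same cache file. The annotation file names below are hypothetical because the post does not show the actual self.path. PurePosixPath is used only so that the example behaves the same on every OS.

```python
from pathlib import PurePosixPath


def cache_path_naive(ann_file):
    # Same construction as the snippet above (PurePosixPath instead of Path)
    return PurePosixPath('/datasets/' + PurePosixPath(ann_file).stem).with_suffix('.cache')


def cache_path_safe(ann_file):
    # Append '.cache' to the full stem instead of replacing its "last suffix"
    return PurePosixPath('/datasets') / (PurePosixPath(ann_file).stem + '.cache')
```

For a stem without dots, both versions return the same path, so the second form can be used as a drop-in replacement.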

Below is the modified loading function. It writes the label cache to the given path (cache_path).

def interhand_cache_labels(self, path=Path('./labels.cache'), prefix=''):  # rewritten cache_labels for InterHand
    # Cache dataset labels, check images and read shapes
    db = COCO(self.path)
    hand_cls = {'right': 0, 'left': 1, 'interacting': 2}
    in_img_path = "/data/InterHand2.6M/images/InterHand2.6M_5fps_batch1/images"
    x = {}
    segments = []  # instance segments
    nm, nf, ne, nc = 0, 0, 0, 0
    pbar = tqdm(db.dataset.items(), desc='Scanning images', total=len(db.dataset.items()))
    for i, (key, value) in enumerate(pbar):
        try:
            im_file = os.path.join(in_img_path, key)
            # verify images
            im = Image.open(im_file)  # check that the image can be opened
            im.verify()  # PIL verify: check file integrity
            shape = exif_size(im)  # image size (width, height)
            assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
            assert im.format.lower() in img_formats, f'invalid image format {im.format}'
            if value['bbox'] != []:
                nf += 1  # label found
                classes = np.array(hand_cls[value['hand_type']], dtype=np.float32).reshape(-1, 1)
                bbox = np.array(value['bbox'], dtype=np.float32).reshape(-1, 4)  # top-left x, y, w, h in pixels
                # original_img = cv2.imread(im_file)
                # bbx = value['bbox']
                # x1 = int(bbx[0])
                # y1 = int(bbx[1])
                # x2 = int(bbx[0] + bbx[2])
                # y2 = int(bbx[1] + bbx[3])
                # temp_image = cv2.rectangle(original_img, (x1, y1), (x2, y2), (0, 0, 255), 2)  # (top-left, bottom-right)
                # cv2.imwrite('./test.jpg', temp_image)  # imread returns BGR and imwrite expects BGR, so no conversion is needed
                l = np.concatenate((classes, xywh2cxcywh(bbox, shape)), 1)  # xywh2cxcywh: author's helper, not shown in this post (pixel xywh -> normalized cx, cy, w, h)
                l = np.array(l, dtype=np.float32)
                # original_img = cv2.imread(im_file)
                # bbx = xywh2cxcywh(bbox, shape)
                # bbx = xywh2xyxy(bbx)
                # x1 = int(bbx[:, 0]*shape[0])
                # y1 = int(bbx[:, 1]*shape[1])
                # x2 = int(bbx[:, 2]*shape[0])
                # y2 = int(bbx[:, 3]*shape[1])
                # temp_image = cv2.rectangle(original_img, (x1, y1), (x2, y2), (0, 0, 255), 2)  # (top-left, bottom-right)
                # cv2.imwrite('./test.jpg', temp_image)  # imread returns BGR and imwrite expects BGR, so no conversion is needed
                if len(l):
                    assert l.shape[1] == 5, 'labels require 5 columns each'  # cls, xywh
                    assert (l >= 0).all(), 'negative labels'  # all values >= 0
                    assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'  # normalized coordinates must not exceed 1 (box inside the image)
                    assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'  # no duplicate boxes
                else:
                    ne += 1  # label empty
                    l = np.zeros((0, 5), dtype=np.float32)
            else:
                nm += 1  # label missing
                l = np.zeros((0, 5), dtype=np.float32)
            x[im_file] = [l, shape, segments]
        except Exception as e:
            nc += 1
            print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

        pbar.desc = f"{prefix}Scanning '{in_img_path}' images and labels... " \
                    f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"  # update the progress bar
    pbar.close()

    if nf == 0:
        print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

    # x['hash'] = get_hash(imgs_path_list)
    x['hash'] = None  # placeholder: __init__ calls cache.pop('hash'), which raises KeyError if the key is missing
    x['results'] = nf, nm, ne, nc, i + 1  # counts
    x['version'] = 0.1  # cache version
    torch.save(x, path)  # save for next time
    logging.info(f'{prefix}New cache created: {path}')
    return x
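The InterHand loader relies on xywh2cxcywh, which is not shown in the post. The comments say bbox is top-left x, y, w, h, and the later assertion requires coordinates normalized to [0, 1]. Together these suggest the helper converts pixel boxes to normalized center format. Below is a hypothetical single-box version of such a helper, plus the per-entry conversion. The entry layout ({'bbox': [...], 'hand_type': ...}) is inferred from the code above, and shape is (width, height) as returned by exif_size.

```python
HAND_CLS = {'right': 0, 'left': 1, 'interacting': 2}


def xywh2cxcywh(bbox, shape):
    """Hypothetical stand-in: top-left (x, y, w, h) in pixels -> normalized (cx, cy, w, h).

    shape is (width, height), the order exif_size() returns.
    """
    img_w, img_h = shape
    x, y, w, h = bbox
    return [(x + w / 2) / img_w, (y + h / 2) / img_h, w / img_w, h / img_h]


def entry_to_label(value, shape):
    """One annotation entry -> one YOLO label row [cls, cx, cy, w, h]."""
    return [float(HAND_CLS[value['hand_type']])] + xywh2cxcywh(value['bbox'], shape)
```

The author's version takes an (N, 4) array, whereas this one handles a single box. A box that extends past the image border fails the `<= 1` assertion and causes the sample to be counted as corrupted. Clipping boxes first is one option, though whether that suits InterHand is an assumption.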
Posted by Zenith_Hugh