YOLOv8 Source Code Analysis (28)

.\yolov8\ultralytics\data\base.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import glob  # file path pattern matching
import math  # math functions
import os  # operating-system utilities
import random  # random number generation
from copy import deepcopy  # deep-copy helper
from multiprocessing.pool import ThreadPool  # thread pool for parallel I/O
from pathlib import Path  # object-oriented path handling
from typing import Optional  # type hints

import cv2  # OpenCV image processing
import numpy as np  # NumPy numerical computing
import psutil  # process and system information
from torch.utils.data import Dataset  # PyTorch dataset base class

from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS  # data-handling helpers
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM  # general utilities


class BaseDataset(Dataset):
    """
    Base dataset class for loading and processing image data.

    Args:
        img_path (str): Path to the folder containing images.
        imgsz (int, optional): Image size. Defaults to 640.
        cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
        augment (bool, optional): If True, data augmentation is applied. Defaults to True.
        hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
        prefix (str, optional): Prefix to print in log messages. Defaults to ''.
        rect (bool, optional): If True, rectangular training is used. Defaults to False.
        batch_size (int, optional): Size of batches. Defaults to None.
        stride (int, optional): Stride. Defaults to 32.
        pad (float, optional): Padding. Defaults to 0.0.
        single_cls (bool, optional): If True, single class training is used. Defaults to False.
        classes (list): List of included classes. Default is None.
        fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).

    Attributes:
        im_files (list): List of image file paths.
        labels (list): List of label data dictionaries.
        ni (int): Number of images in the dataset.
        ims (list): List of loaded images.
        npy_files (list): List of numpy file paths.
        transforms (callable): Image transformation function.
    """

    def __init__(
        self,
        img_path,
        imgsz=640,
        cache=False,
        augment=True,
        hyp=DEFAULT_CFG,
        prefix="",
        rect=False,
        batch_size=16,
        stride=32,
        pad=0.5,
        single_cls=False,
        classes=None,
        fraction=1.0,
    ):
        """Initialize BaseDataset with given configuration and options."""
        super().__init__()  # call the parent Dataset initializer
        self.img_path = img_path  # path(s) to the images
        self.imgsz = imgsz  # target image size
        self.augment = augment  # whether to apply data augmentation
        self.single_cls = single_cls  # whether to treat all labels as one class
        self.prefix = prefix  # prefix for log messages
        self.fraction = fraction  # fraction of the dataset to use
        self.im_files = self.get_img_files(self.img_path)  # collect all image file paths
        self.labels = self.get_labels()  # load the labels
        self.update_labels(include_class=classes)  # single_cls and include_class
        self.ni = len(self.labels)  # number of images
        self.rect = rect  # rectangular training
        self.batch_size = batch_size
        self.stride = stride
        self.pad = pad
        if self.rect:
            # Rectangular training needs a known batch size to group images by aspect ratio
            assert self.batch_size is not None
            self.set_rectangle()

        # Buffer for mosaic images
        self.buffer = []  # buffer size = batch size
        # Maximum buffer length: min(number of images, 8 x batch size, 1000) when augmenting, else 0
        self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

        # Cache images (options are True, False, None, "ram", "disk")
        self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
        # Pre-compute the *.npy cache path for every image file
        self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
        # Normalize the cache option to "ram", "disk" or None
        self.cache = cache.lower() if isinstance(cache, str) else "ram" if cache is True else None
        # Cache to disk, or to RAM when check_cache_ram() confirms enough memory is available
        if (self.cache == "ram" and self.check_cache_ram()) or self.cache == "disk":
            self.cache_images()

        # Build the image transformation pipeline
        self.transforms = self.build_transforms(hyp=hyp)

    def get_img_files(self, img_path):
        """Read image files."""
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)  # os-agnostic path handling
                if p.is_dir():  # dir
                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)  # recursively collect every file in the tree
                    # F = list(p.rglob('*.*'))  # pathlib alternative
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()  # read the file list, one path per line
                        parent = str(p.parent) + os.sep
                        f += [x.replace("./", parent) if x.startswith("./") else x for x in t]  # local to global path
                        # F += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
            # Keep only files with a supported image extension, sorted
            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib alternative
            assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
        except Exception as e:
            raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
        if self.fraction < 1:
            im_files = im_files[: round(len(im_files) * self.fraction)]  # retain only a fraction of the dataset
        return im_files

    def update_labels(self, include_class: Optional[list]):
        """Update labels to include only these classes (optional)."""
        include_class_array = np.array(include_class).reshape(1, -1)
        for i in range(len(self.labels)):
            if include_class is not None:
                cls = self.labels[i]["cls"]
                bboxes = self.labels[i]["bboxes"]
                segments = self.labels[i]["segments"]
                keypoints = self.labels[i]["keypoints"]
                j = (cls == include_class_array).any(1)  # boolean mask of labels whose class is in include_class
                self.labels[i]["cls"] = cls[j]
                self.labels[i]["bboxes"] = bboxes[j]
                if segments:
                    self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx]  # keep matching segments
                if keypoints is not None:
                    self.labels[i]["keypoints"] = keypoints[j]  # keep matching keypoints
            if self.single_cls:
                self.labels[i]["cls"][:, 0] = 0  # collapse every class to 0
    def load_image(self, i, rect_mode=True):
        """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]

        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                try:
                    im = np.load(fn)
                except Exception as e:
                    # Warn about and delete the corrupt *.npy file, then fall back to the original image
                    LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}")
                    Path(fn).unlink(missing_ok=True)
                    im = cv2.imread(f)  # BGR
            else:  # read image
                im = cv2.imread(f)  # BGR

            if im is None:
                raise FileNotFoundError(f"Image Not Found {f}")

            h0, w0 = im.shape[:2]  # orig hw
            if rect_mode:  # resize long side to imgsz while maintaining aspect ratio
                r = self.imgsz / max(h0, w0)  # ratio
                if r != 1:  # if sizes are not equal
                    w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))
                    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
            elif not (h0 == w0 == self.imgsz):  # resize by stretching image to square imgsz
                im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)

            # During augmented training, keep the image and its shapes in the mosaic buffer
            if self.augment:
                self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
                self.buffer.append(i)
                if 1 < len(self.buffer) >= self.max_buffer_length:  # prevent empty buffer
                    j = self.buffer.pop(0)  # evict the oldest buffered image
                    if self.cache != "ram":
                        self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None

            return im, (h0, w0), im.shape[:2]

        # Already cached in RAM: return the cached image with its original and resized shapes
        return self.ims[i], self.im_hw0[i], self.im_hw[i]

    def cache_images(self):
        """Cache images to memory or disk."""
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
        # Pick the caching function and storage medium from the cache option
        fcn, storage = (self.cache_images_to_disk, "Disk") if self.cache == "disk" else (self.load_image, "RAM")

        with ThreadPool(NUM_THREADS) as pool:
            # Load or cache images in parallel
            results = pool.imap(fcn, range(self.ni))
            pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if self.cache == "disk":
                    b += self.npy_files[i].stat().st_size  # accumulate the size of the cached *.npy file
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += self.ims[i].nbytes
                pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {storage})"
            pbar.close()

    def cache_images_to_disk(self, i):
        """Saves an image as an *.npy file for faster loading."""
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]), allow_pickle=False)  # save the image as *.npy

    def check_cache_ram(self, safety_margin=0.5):
        """Check image caching requirements vs available memory."""
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
        n = min(self.ni, 30)  # sample at most 30 images to extrapolate memory use
        for _ in range(n):
            im = cv2.imread(random.choice(self.im_files))  # read a random sample image
            ratio = self.imgsz / max(im.shape[0], im.shape[1])  # ratio of target size to the longer side
            b += im.nbytes * ratio**2  # estimated bytes after resizing to imgsz
        mem_required = b * self.ni / n * (1 + safety_margin)  # bytes needed to cache the whole dataset
        mem = psutil.virtual_memory()  # system memory statistics
        success = mem_required < mem.available  # True if the dataset fits in available RAM
        if not success:  # not enough RAM
            self.cache = None  # disable caching
            LOGGER.info(
                f"{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images "
                f"with {int(safety_margin * 100)}% safety margin but only "
                f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, not caching images ⚠️"
            )
        return success
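
The estimate is easy to reproduce by hand. A minimal sketch with hypothetical numbers (a 1280×960 image, imgsz=640, a 5000-image dataset):

```py
nbytes = 1280 * 960 * 3                  # uint8 BGR image, bytes before resizing
ratio = 640 / max(1280, 960)             # long side resized to imgsz -> 0.5
per_image = nbytes * ratio**2            # ~0.92 MB per image after resizing
required = per_image * 5000 * (1 + 0.5)  # 5000 images with a 50% safety margin
print(f"{required / (1 << 30):.1f} GB")  # -> 6.4 GB
```
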

    def set_rectangle(self):
        """Sets the shape of bounding boxes for YOLO detections as rectangles."""
        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index of each image
        nb = bi[-1] + 1  # number of batches

        s = np.array([x.pop("shape") for x in self.labels])  # image shapes as (height, width)
        ar = s[:, 0] / s[:, 1]  # aspect ratio h/w
        irect = ar.argsort()  # indices that sort images by aspect ratio
        self.im_files = [self.im_files[i] for i in irect]  # reorder image files
        self.labels = [self.labels[i] for i in irect]  # reorder labels
        ar = ar[irect]

        # Set training image shapes per batch
        shapes = [[1, 1]] * nb
        for i in range(nb):
            ari = ar[bi == i]  # aspect ratios of the images in this batch
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]  # all images wider than tall: shrink only the height dimension
            elif mini > 1:
                shapes[i] = [1, 1 / mini]  # all images taller than wide: shrink only the width dimension

        # Scale to imgsz, add padding, and round up to the nearest stride multiple
        self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
        self.batch = bi  # batch index of each image
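
A worked example of the final arithmetic, assuming imgsz=640, stride=32, pad=0.5 and a batch whose images all have h/w below 1 (hypothetical ratios):

```py
import numpy as np

imgsz, stride, pad = 640, 32, 0.5
ar = np.array([0.5, 0.75])  # h/w ratios of one sorted batch (hypothetical)
shape = [ar.max(), 1]       # maxi < 1 branch: shrink only the height
batch_shape = np.ceil(np.array(shape) * imgsz / stride + pad).astype(int) * stride
print(batch_shape)          # -> [512 672], i.e. (h, w) rounded up to stride multiples
```
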

    def __getitem__(self, index):
        """Returns transformed label information for given index."""
        return self.transforms(self.get_image_and_label(index))

    def get_image_and_label(self, index):
        """Get and return label information from the dataset."""
        label = deepcopy(self.labels[index])  # deep copy so transforms never mutate the originals, see https://github.com/ultralytics/ultralytics/pull/1948
        label.pop("shape", None)  # shape is for rect, remove it
        label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
        # Per-axis resize ratios, used for evaluation
        label["ratio_pad"] = (
            label["resized_shape"][0] / label["ori_shape"][0],
            label["resized_shape"][1] / label["ori_shape"][1],
        )
        if self.rect:
            label["rect_shape"] = self.batch_shapes[self.batch[index]]  # target shape of this image's batch
        return self.update_labels_info(label)

    def __len__(self):
        """Returns the length of the labels list for the dataset."""
        return len(self.labels)

    def update_labels_info(self, label):
        """Custom your label format here."""
        return label

    def build_transforms(self, hyp=None):
        """
        Users can customize augmentations here.

        Example:
            ```py
            if self.augment:
                # Training transforms
                return Compose([])
            else:
                # Val transforms
                return Compose([])
            ```
        """
        # Users customize augmentations here; the base class deliberately leaves it unimplemented
        raise NotImplementedError

    def get_labels(self):
        """
        Users can customize their own format here.

        Note:
            Ensure output is a dictionary with the following keys:
            ```py
            dict(
                im_file=im_file,
                shape=shape,  # format: (height, width)
                cls=cls,
                bboxes=bboxes, # xywh
                segments=segments,  # xy
                keypoints=keypoints, # xy
                normalized=True, # or False
                bbox_format="xyxy",  # or xywh, ltwh
            )
            ```
        """
        # Users customize the label output format here; the base class deliberately leaves it unimplemented
        raise NotImplementedError
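
BaseDataset leaves exactly these two hooks abstract. A minimal, hypothetical subclass sketch (not the library's own implementation; YOLODataset in dataset.py later in this part is the real one) only needs to return label dicts with the documented keys and a callable transform:

```py
import numpy as np

class MinimalDataset(BaseDataset):
    """Hypothetical subclass sketch: empty labels and an identity transform."""

    def get_labels(self):
        # One background (empty) label per image; 'shape' is (height, width),
        # fixed to 640x640 here purely for illustration
        return [
            dict(
                im_file=f,
                shape=(640, 640),
                cls=np.zeros((0, 1), dtype=np.float32),
                bboxes=np.zeros((0, 4), dtype=np.float32),
                segments=[],
                keypoints=None,
                normalized=True,
                bbox_format="xywh",
            )
            for f in self.im_files
        ]

    def build_transforms(self, hyp=None):
        # Identity transform: hand the label dict through unchanged
        return lambda labels: labels
```
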

.\yolov8\ultralytics\data\build.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import os
import random
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torch.utils.data import dataloader, distributed

# Custom dataset classes
from ultralytics.data.dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset
# Data loaders for the various inference source types
from ultralytics.data.loaders import (
    LOADERS,
    LoadImagesAndVideos,
    LoadPilAndNumpy,
    LoadScreenshots,
    LoadStreams,
    LoadTensor,
    SourceTypes,
    autocast_list,
)
# Data-related utilities and constants
from ultralytics.data.utils import IMG_FORMATS, PIN_MEMORY, VID_FORMATS
# General helpers
from ultralytics.utils import RANK, colorstr
# File checking helper
from ultralytics.utils.checks import check_file


class InfiniteDataLoader(dataloader.DataLoader):
    """
    Dataloader that reuses workers.

    Uses same syntax as vanilla DataLoader.
    """

    def __init__(self, *args, **kwargs):
        """Dataloader that infinitely recycles workers, inherits from DataLoader."""
        super().__init__(*args, **kwargs)
        # Wrap the batch sampler in _RepeatSampler so workers are reused forever
        object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()  # create the single long-lived iterator

    def __len__(self):
        """Returns the length of the batch sampler's sampler."""
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        """Creates a sampler that repeats indefinitely."""
        for _ in range(len(self)):
            yield next(self.iterator)

    def reset(self):
        """
        Reset iterator.

        This is useful when we want to modify settings of dataset while training.
        """
        self.iterator = self._get_iterator()  # rebuild the iterator so dataset changes take effect


class _RepeatSampler:
    """
    Sampler that repeats forever.

    Args:
        sampler (Dataset.sampler): The sampler to repeat.
    """

    def __init__(self, sampler):
        """Initializes an object that repeats a given sampler indefinitely."""
        self.sampler = sampler

    def __iter__(self):
        """Iterates over the 'sampler' and yields its contents."""
        while True:
            yield from iter(self.sampler)


def seed_worker(worker_id):  # noqa
    """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
    worker_seed = torch.initial_seed() % 2**32  # derive a per-worker seed from the torch seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
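
A small sanity check of the recycling behavior (toy tensors, num_workers=0 so it runs anywhere): both epochs below are served by the single long-lived iterator created in `__init__`, so workers are never torn down between epochs.

```py
import torch
from torch.utils.data import TensorDataset

ds = TensorDataset(torch.arange(8).float())  # hypothetical 8-sample dataset
loader = InfiniteDataLoader(ds, batch_size=4, shuffle=False, num_workers=0)
for epoch in range(2):  # two epochs reuse one underlying iterator
    for (xb,) in loader:
        print(epoch, xb.tolist())
# -> 0 [0.0, 1.0, 2.0, 3.0]
#    0 [4.0, 5.0, 6.0, 7.0]
#    1 [0.0, 1.0, 2.0, 3.0]
#    1 [4.0, 5.0, 6.0, 7.0]
```
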


def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
    """Build YOLO Dataset."""
    # Pick the multi-modal or the plain YOLO dataset class
    dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
    return dataset(
        img_path=img_path,
        imgsz=cfg.imgsz,
        batch_size=batch,
        augment=mode == "train",  # augment only in training mode
        hyp=cfg,
        rect=cfg.rect or rect,  # rectangular batches
        cache=cfg.cache or None,
        single_cls=cfg.single_cls or False,
        stride=int(stride),
        pad=0.0 if mode == "train" else 0.5,  # no padding in training, half-stride padding otherwise
        prefix=colorstr(f"{mode}: "),  # log prefix carrying the mode
        task=cfg.task,
        classes=cfg.classes,
        data=data,
        fraction=cfg.fraction if mode == "train" else 1.0,  # use the full dataset outside training
    )


def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
    """Build YOLO Dataset."""
    # Return a GroundingDataset for training or validation
    return GroundingDataset(
        img_path=img_path,
        json_file=json_file,  # JSON file with the grounding annotations
        imgsz=cfg.imgsz,
        batch_size=batch,
        augment=mode == "train",
        hyp=cfg,
        rect=cfg.rect or rect,
        cache=cfg.cache or None,
        single_cls=cfg.single_cls or False,
        stride=int(stride),
        pad=0.0 if mode == "train" else 0.5,
        prefix=colorstr(f"{mode}: "),
        task=cfg.task,
        classes=cfg.classes,
        fraction=cfg.fraction if mode == "train" else 1.0,
    )


def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
    """Return an InfiniteDataLoader or DataLoader for training or validation set."""
    batch = min(batch, len(dataset))  # never exceed the dataset size
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min(os.cpu_count() // max(nd, 1), workers)  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + RANK)  # fixed base seed offset by the process rank
    return InfiniteDataLoader(
        dataset=dataset,
        batch_size=batch,
        shuffle=shuffle and sampler is None,  # shuffling is delegated to the sampler in DDP
        num_workers=nw,
        sampler=sampler,
        pin_memory=PIN_MEMORY,
        collate_fn=getattr(dataset, "collate_fn", None),  # use the dataset's collate function if it has one
        worker_init_fn=seed_worker,
        generator=generator,
    )


def check_source(source):
    """Check source type and return corresponding flag values."""
    webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
    if isinstance(source, (str, int, Path)):  # int for a local camera index
        source = str(source)
        is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS)  # supported image or video extension
        is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
        webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
        screenshot = source.lower() == "screen"
        if is_url and is_file:
            source = check_file(source)  # download the remote file
    elif isinstance(source, LOADERS):  # an already-constructed loader
        in_memory = True
    elif isinstance(source, (list, tuple)):
        source = autocast_list(source)  # convert all list elements to PIL images or np arrays
        from_img = True
    elif isinstance(source, (Image.Image, np.ndarray)):  # a single PIL image or np array
        from_img = True
    elif isinstance(source, torch.Tensor):
        tensor = True
    else:
        raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict")

    return source, webcam, screenshot, from_img, in_memory, tensor
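
A few illustrative calls (no files or network needed; the array and tensor are dummies) showing which flag each source type trips:

```py
import numpy as np
import torch

print(check_source(0)[1])                                # "0" is numeric  -> webcam=True
print(check_source("screen")[2])                         # -> screenshot=True
print(check_source(np.zeros((64, 64, 3), np.uint8))[3])  # -> from_img=True
print(check_source(torch.zeros(1, 3, 64, 64))[5])        # -> tensor=True
```
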


def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
    """
    Loads an inference source for object detection and applies necessary transformations.

    Args:
        source (str | Path | Tensor | PIL.Image | np.ndarray): The input source for inference; a file path, URL,
            tensor, image object, etc.
        batch (int, optional): Batch size for the dataloader. Defaults to 1.
        vid_stride (int, optional): The frame interval for video sources. Defaults to 1.
        buffer (bool, optional): Whether stream frames will be buffered. Defaults to False.

    Returns:
        dataset (Dataset): A dataset object for the specified input source.
    """
    # Normalize the source and detect its type
    source, stream, screenshot, from_img, in_memory, tensor = check_source(source)

    # In-memory loaders already carry their own source_type; otherwise build one from the flags
    source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)

    # Dataloader selection
    if tensor:
        dataset = LoadTensor(source)  # torch.Tensor input
    elif in_memory:
        dataset = source  # an existing loader is used as-is
    elif stream:
        dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer)  # webcam / video stream
    elif screenshot:
        dataset = LoadScreenshots(source)  # screen capture
    elif from_img:
        dataset = LoadPilAndNumpy(source)  # PIL images or numpy arrays
    else:
        dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride)  # image/video files

    # Attach the source type to the dataset
    setattr(dataset, "source_type", source_type)

    return dataset
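
A hedged end-to-end example with an in-memory dummy frame (no files or network): assuming LoadPilAndNumpy accepts a bare ndarray, the array routes through check_source to that loader.

```py
import numpy as np

im = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy HxWx3 BGR frame
dataset = load_inference_source(im)
print(type(dataset).__name__)  # -> LoadPilAndNumpy
print(dataset.source_type)     # SourceTypes(..., from_img=True, ...)
```
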

.\yolov8\ultralytics\data\converter.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import json
from collections import defaultdict
from pathlib import Path

import cv2
import numpy as np

# Ultralytics logging and progress-bar helpers
from ultralytics.utils import LOGGER, TQDM
# Path-increment helper from the Ultralytics file utilities
from ultralytics.utils.files import increment_path

# Maps COCO 91-class IDs to COCO 80-class IDs
def coco91_to_coco80_class():
    """
    Converts 91-index COCO class IDs to 80-index COCO class IDs.

    Returns:
        (list): A list of 91 class IDs where the index represents the 80-index class ID and the value is the
            corresponding 91-index class ID.
    """
    return [
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25,
        None, None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47,
        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67,
        68, 69, 70, 71, 72, None, 73, 74, 75, 76, 77, 78, 79, None,
    ]


# Maps COCO 80-class IDs to COCO 91-class IDs
def coco80_to_coco91_class():
    """
    Converts 80-index (val2014) to 91-index (paper).
    For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/.

    Example:
        ```python
        import numpy as np

        a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
        b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
        x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
        x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
        ```
    """
    # 1-based COCO paper class IDs for each of the 80 val2014 classes
    return [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
        62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
    ]
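
The two tables are inverses of each other, which is easy to verify: mapping an 80-index class to its 1-based 91-index paper ID and back returns the original class.

```py
coco80 = coco91_to_coco80_class()  # 91 entries, None where the paper ID is unused
coco91 = coco80_to_coco91_class()  # 80 entries of 1-based paper IDs

for c80 in range(80):
    c91 = coco91[c80]              # 1-based paper ID
    assert coco80[c91 - 1] == c80  # round trip back to the 80-class ID
print("mappings are consistent")
```
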
def convert_coco(
    labels_dir="../coco/annotations/",
    save_dir="coco_converted/",
    use_segments=False,
    use_keypoints=False,
    cls91to80=True,
    lvis=False,
):
    """
    Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.

    Args:
        labels_dir (str, optional): Path to directory containing COCO dataset annotation files.
        save_dir (str, optional): Path to directory to save results to.
        use_segments (bool, optional): Whether to include segmentation masks in the output.
        use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
        cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
        lvis (bool, optional): Whether to convert data in lvis dataset way.

    Example:
        ```python
        from ultralytics.data.converter import convert_coco

        convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True)
        convert_coco('../datasets/lvis/annotations/', use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)
        ```

    Output:
        Generates output files in the specified output directory.
    """

    # Create dataset directory
    save_dir = increment_path(save_dir)  # increment save_dir if it already exists
    for p in save_dir / "labels", save_dir / "images":
        p.mkdir(parents=True, exist_ok=True)  # make dir

    # Convert classes
    coco80 = coco91_to_coco80_class()  # map 91-index COCO classes to 80-index

    # Import json
    LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")

def convert_dota_to_yolo_obb(dota_root_path: str):
    """
    Converts DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format.

    The function processes images in the 'train' and 'val' folders of the DOTA dataset. For each image, it reads the
    associated label from the original labels directory and writes new labels in YOLO OBB format to a new directory.

    Args:
        dota_root_path (str): The root directory path of the DOTA dataset.

    Example:
        ```python
        from ultralytics.data.converter import convert_dota_to_yolo_obb

        convert_dota_to_yolo_obb('path/to/DOTA')
        ```

    Notes:
        The directory structure assumed for the DOTA dataset:

            - DOTA
                ├─ images
                │   ├─ train
                │   └─ val
                └─ labels
                    ├─ train_original
                    └─ val_original

        After execution, the function will organize the labels into:

            - DOTA
                └─ labels
                    ├─ train
                    └─ val
    """
    dota_root_path = Path(dota_root_path)

    # Class names to indices mapping
    class_mapping = {
        "plane": 0,
        "ship": 1,
        "storage-tank": 2,
        "baseball-diamond": 3,
        "tennis-court": 4,
        "basketball-court": 5,
        "ground-track-field": 6,
        "harbor": 7,
        "bridge": 8,
        "large-vehicle": 9,
        "small-vehicle": 10,
        "helicopter": 11,
        "roundabout": 12,
        "soccer-ball-field": 13,
        "swimming-pool": 14,
        "container-crane": 15,
        "airport": 16,
        "helipad": 17,
    }
    
    def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir):
        """Converts a single image's DOTA annotation to YOLO OBB format and saves it to a specified directory."""
        orig_label_path = orig_label_dir / f"{image_name}.txt"
        save_path = save_dir / f"{image_name}.txt"

        with orig_label_path.open("r") as f, save_path.open("w") as g:
            lines = f.readlines()
            for line in lines:
                parts = line.strip().split()
                if len(parts) < 9:
                    continue  # skip malformed lines
                class_name = parts[8]  # the class name is the 9th field
                class_idx = class_mapping[class_name]  # map the name to its integer index
                coords = [float(p) for p in parts[:8]]  # the 8 corner coordinates
                # Normalize x coordinates by the width and y coordinates by the height
                normalized_coords = [
                    coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)
                ]
                formatted_coords = ["{:.6g}".format(coord) for coord in normalized_coords]  # 6 significant digits
                g.write(f"{class_idx} {' '.join(formatted_coords)}\n")

    for phase in ["train", "val"]:
        # Image, original-label and output-label directories for this phase
        image_dir = dota_root_path / "images" / phase
        orig_label_dir = dota_root_path / "labels" / f"{phase}_original"
        save_dir = dota_root_path / "labels" / phase

        save_dir.mkdir(parents=True, exist_ok=True)  # create the output directory if needed

        image_paths = list(image_dir.iterdir())
        for image_path in TQDM(image_paths, desc=f"Processing {phase} images"):
            if image_path.suffix != ".png":
                continue  # DOTA images are PNG; skip everything else
            image_name_without_ext = image_path.stem
            img = cv2.imread(str(image_path))
            h, w = img.shape[:2]  # image height and width for normalization
            convert_label(image_name_without_ext, w, h, orig_label_dir, save_dir)
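
A worked example of the per-line conversion above, using a hypothetical 1024×1024 image and one DOTA label line ("plane" maps to class 0):

```py
line = "100 200 300 200 300 400 100 400 plane 0"  # hypothetical DOTA annotation
parts = line.strip().split()
coords = [float(p) for p in parts[:8]]
w = h = 1024
normalized = [coords[i] / (w if i % 2 == 0 else h) for i in range(8)]
print("0 " + " ".join("{:.6g}".format(c) for c in normalized))
# -> 0 0.0976562 0.195312 0.292969 0.195312 0.292969 0.390625 0.0976562 0.390625
```
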

def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"):
    """
    Converts an existing object detection dataset (bounding boxes) to a segmentation or oriented bounding box (OBB)
    dataset in YOLO format, generating the segmentation data with the SAM auto-annotator as needed.

    Args:
        im_dir (str | Path): Path to the image directory to convert.
        save_dir (str | Path): Path to save the generated labels; if None, labels are saved into a `labels-segment`
            directory at the same level as `im_dir`. Defaults to None.
        sam_model (str): Segmentation model to use for intermediate segmentation data; optional.

    Notes:
        The input directory structure assumed for the dataset:

            - im_dir
                ├─ 001.jpg
                ├─ ..
                └─ NNN.jpg
            - labels
                ├─ 001.txt
                ├─ ..
                └─ NNN.txt
    """
    from tqdm import tqdm

    from ultralytics import SAM
    from ultralytics.data import YOLODataset
    from ultralytics.utils import LOGGER
    from ultralytics.utils.ops import xywh2xyxy

    # NOTE: add placeholder to pass class index check
    dataset = YOLODataset(im_dir, data=dict(names=list(range(1000))))
    if len(dataset.labels[0]["segments"]) > 0:  # segmentation data already present
        LOGGER.info("Segmentation labels detected, no need to generate new ones!")
        return

    LOGGER.info("Detection labels detected, generating segment labels by SAM model!")
    sam_model = SAM(sam_model)  # instantiate the SAM auto-annotator
    for label in tqdm(dataset.labels, total=len(dataset.labels), desc="Generating segment labels"):
        h, w = label["shape"]
        boxes = label["bboxes"]
        if len(boxes) == 0:  # skip empty labels
            continue
        boxes[:, [0, 2]] *= w  # denormalize x coordinates to pixels
        boxes[:, [1, 3]] *= h  # denormalize y coordinates to pixels
        im = cv2.imread(label["im_file"])
        sam_results = sam_model(im, bboxes=xywh2xyxy(boxes), verbose=False, save=False)  # prompt SAM with the boxes
        label["segments"] = sam_results[0].masks.xyn  # store the normalized polygon masks

    save_dir = Path(save_dir) if save_dir else Path(im_dir).parent / "labels-segment"
    save_dir.mkdir(parents=True, exist_ok=True)

    for label in dataset.labels:
        texts = []
        lb_name = Path(label["im_file"]).with_suffix(".txt").name
        txt_file = save_dir / lb_name
        cls = label["cls"]
        for i, s in enumerate(label["segments"]):
            line = (int(cls[i]), *s.reshape(-1))  # class id followed by flattened xy pairs
            texts.append(("%g " * len(line)).rstrip() % line)
        if texts:
            with open(txt_file, "a") as f:
                f.writelines(text + "\n" for text in texts)
    LOGGER.info(f"Generated segment labels saved in {save_dir}")
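
The write format in that last loop is worth seeing in isolation: `%g` trims trailing zeros, so each label line stays compact (values hypothetical):

```py
line = (2, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6)    # class id + flattened xy pairs
print(("%g " * len(line)).rstrip() % line)  # -> 2 0.1 0.2 0.3 0.4 0.5 0.6
```
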

.\yolov8\ultralytics\data\dataset.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import json
from collections import defaultdict
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset

# Ultralytics utility functions and classes
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments
from ultralytics.utils.torch_utils import TORCHVISION_0_18

# Data augmentation modules
from .augment import (
    Compose,
    Format,
    Instances,
    LetterBox,
    RandomLoadText,
    classify_augmentations,
    classify_transforms,
    v8_transforms,
)
# Base dataset class and utility functions
from .base import BaseDataset
from .utils import (
    HELP_URL,
    LOGGER,
    get_hash,
    img2label_paths,
    load_dataset_cache_file,
    save_dataset_cache_file,
    verify_image,
    verify_image_label,
)

# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3"

class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        task (str): An explicit arg to point current task, Defaults to 'detect'.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """Initializes the YOLODataset with optional configurations for segments and keypoints."""
        # Derive the label types to use from the task
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        self.data = data
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(*args, **kwargs)

    def cache_labels(self, path=Path("./labels.cache")):
        """
        Cache dataset labels, check images and read shapes.

        Args:
            path (Path): Path where to save the cache file. Default is Path('./labels.cache').

        Returns:
            (dict): labels.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        nkpt, ndim = self.data.get("kpt_shape", (0, 0))  # keypoint shape from the dataset YAML
        if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
            raise ValueError(
                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
            )
        # Verify every image/label pair in parallel
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.use_keypoints),
                    repeat(len(self.data["names"])),
                    repeat(nkpt),
                    repeat(ndim),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                # Accumulate the missing / found / empty / corrupt counters
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # the image passed verification; record its label
                    x["labels"].append(
                        {
                            "im_file": im_file,
                            "shape": shape,
                            "cls": lb[:, 0:1],  # n, 1
                            "bboxes": lb[:, 1:],  # n, 4
                            "segments": segments,
                            "keypoints": keypoint,
                            "normalized": True,
                            "bbox_format": "xywh",
                        }
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)  # hash of the dataset files
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self):
        """Returns dictionary of labels for YOLO training."""
        self.label_files = img2label_paths(self.im_files)  # the label file path for every image file
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError):
            cache, exists = self.cache_labels(cache_path), False  # run cache ops

        # Display cache results
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache
        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
        labels = cache["labels"]
        if not labels:
            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check whether the dataset is all boxes or all segments
        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            # Mismatched box and segment counts: warn and drop all segments
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def build_transforms(self, hyp=None):
        """Builds and appends transforms to the list."""
        if self.augment:
            # Disable mosaic and mixup in rectangular mode
            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)  # full training augmentation pipeline
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])  # val: letterbox only
        # Append the final formatting transform
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                return_obb=self.use_obb,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
                bgr=hyp.bgr if self.augment else 0.0,  # only affects training
            )
        )
        return transforms

    def close_mosaic(self, hyp):
        """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
        hyp.mosaic = 0.0  # set mosaic ratio to 0.0
        hyp.copy_paste = 0.0  # keep the same behavior as previous v8 close-mosaic
        hyp.mixup = 0.0  # keep the same behavior as previous v8 close-mosaic
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """
        Custom your label format here.

        Note:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # Use fewer resample points for oriented boxes
        segment_resamples = 100 if self.use_obb else 1000
        if len(segments) > 0:
            # Resample each polygon to a fixed point count so they stack into one array
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        # Bundle boxes, segments and keypoints into a single Instances object
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """Collates data samples into batches."""
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))  # transpose: one tuple of values per key
        for i, k in enumerate(keys):
            value = values[i]
            if k == "img":
                value = torch.stack(value, 0)  # stack images into a single tensor
            if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                value = torch.cat(value, 0)  # concatenate per-instance tensors along dim 0
            new_batch[k] = value
        new_batch["batch_idx"] = list(new_batch["batch_idx"])
        for i in range(len(new_batch["batch_idx"])):
            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
        return new_batch

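
What collate_fn produces, on two toy samples (hypothetical shapes): images are stacked into one tensor, per-instance tensors are concatenated, and batch_idx gains the image index of each box.

```py
import torch

batch = [
    {"img": torch.zeros(3, 4, 4), "cls": torch.zeros(2, 1),
     "bboxes": torch.zeros(2, 4), "batch_idx": torch.zeros(2)},
    {"img": torch.ones(3, 4, 4), "cls": torch.zeros(1, 1),
     "bboxes": torch.zeros(1, 4), "batch_idx": torch.zeros(1)},
]
out = YOLODataset.collate_fn(batch)
print(out["img"].shape)     # -> torch.Size([2, 3, 4, 4])
print(out["bboxes"].shape)  # -> torch.Size([3, 4])
print(out["batch_idx"])     # -> tensor([0., 0., 1.])
```
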
class YOLOMultiModalDataset(YOLODataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        task (str): An explicit arg to point current task, Defaults to 'detect'.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """Initializes a dataset object for object detection tasks with optional specifications."""
        super().__init__(*args, data=data, task=task, **kwargs)

    def update_labels_info(self, label):
        """Add texts information for multi-modal model training."""
        labels = super().update_labels_info(label)
        # NOTE: some categories are concatenated with its synonyms by `/`.
        labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]  # split synonym lists into text prompts
        return labels

    def build_transforms(self, hyp=None):
        """Enhances data transformations with optional text augmentation for multi-modal training."""
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now.
            transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))  # text loading before the final Format transform
        return transforms


class GroundingDataset(YOLODataset):
    """Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format."""

    def __init__(self, *args, task="detect", json_file, **kwargs):
        """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
        assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
        self.json_file = json_file
        super().__init__(*args, task=task, data={}, **kwargs)

    def get_img_files(self, img_path):
        """The image files would be read in `get_labels` function, return empty list here."""
        return []

    def get_labels(self):
        """Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
        labels = []
        LOGGER.info("Loading annotation file...")
        with open(self.json_file, "r") as f:
            annotations = json.load(f)
        images = {f'{x["id"]:d}': x for x in annotations["images"]}  # image dict keyed by image id
        img_to_anns = defaultdict(list)
        for ann in annotations["annotations"]:
            img_to_anns[ann["image_id"]].append(ann)  # group annotations by image id
        for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
            img = images[f"{img_id:d}"]
            h, w, f = img["height"], img["width"], img["file_name"]
            im_file = Path(self.img_path) / f
            if not im_file.exists():
                continue  # skip annotations whose image file is missing
            self.im_files.append(str(im_file))
            bboxes = []
            cat2id = {}  # category-phrase -> class-id mapping, built per image
            texts = []
            for ann in anns:
                if ann["iscrowd"]:
                    continue  # skip crowd annotations
                box = np.array(ann["bbox"], dtype=np.float32)
                box[:2] += box[2:] / 2  # convert top-left xy to center xy
                box[[0, 2]] /= float(w)  # normalize x by the image width
                box[[1, 3]] /= float(h)  # normalize y by the image height
                if box[2] <= 0 or box[3] <= 0:
                    continue  # skip degenerate boxes

                cat_name = " ".join([img["caption"][t[0]:t[1]] for t in ann["tokens_positive"]])  # phrase from the caption
                if cat_name not in cat2id:
                    cat2id[cat_name] = len(cat2id)  # assign the phrase a new class id
                    texts.append([cat_name])
                cls = cat2id[cat_name]
                box = [cls] + box.tolist()  # prepend the class id to the box
                if box not in bboxes:
                    bboxes.append(box)  # deduplicate identical boxes
            lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
            labels.append(
                {
                    "im_file": im_file,
                    "shape": (h, w),
                    "cls": lb[:, 0:1],  # n, 1
                    "bboxes": lb[:, 1:],  # n, 4
                    "normalized": True,
                    "bbox_format": "xywh",
                    "texts": texts,
                }
            )
        return labels

    def build_transforms(self, hyp=None):
        """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now.
            transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))  # text loading before the final Format transform
        return transforms


class YOLOConcatDataset(ConcatDataset):
    """
    Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.
    """

    @staticmethod
    def collate_fn(batch):
        """Collates data samples into batches."""
        return YOLODataset.collate_fn(batch)



# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
    """
    Semantic Segmentation Dataset.

    This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
    from the BaseDataset class.

    Note:
        This class is currently a placeholder and needs to be populated with methods and attributes for supporting
        semantic segmentation tasks.
    """

    def __init__(self):
        """Initialize a SemanticDataset object."""
        super().__init__()

class ClassificationDataset:
    """
    Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
    augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
    learning models, with optional image transformations and caching mechanisms to speed up training.

    This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
    in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
    to ensure data integrity and consistency.

    Attributes:
        cache_ram (bool): Indicates if caching in RAM is enabled.
        cache_disk (bool): Indicates if caching on disk is enabled.
        samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
                        file (if caching on disk), and optionally the loaded image array (if caching in RAM).
        torch_transforms (callable): PyTorch transforms to be applied to the images.
    """

    def __getitem__(self, i):
        """Returns subset of data and targets corresponding to given indices."""
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram:
            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        # Convert NumPy array to PIL image
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        sample = self.torch_transforms(im)
        return {"img": sample, "cls": j}

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.samples)

    def verify_images(self):
        """Verify all images in dataset."""
        # 构建描述信息,指定要扫描的根目录
        desc = f"{self.prefix}Scanning {self.root}..."
        # 根据根目录生成对应的缓存文件路径
        path = Path(self.root).with_suffix(".cache")  # *.cache file path
        
        # 尝试加载缓存文件,处理可能出现的文件未找到、断言错误和属性错误
        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
            # 加载数据集缓存文件
            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
            # 断言缓存文件版本与当前版本匹配
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            # 断言缓存文件的哈希与数据集样本的哈希一致
            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
            # 解构缓存结果,包括发现的、丢失的、空的、损坏的样本数量以及样本列表
            nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
            # 如果在主机的本地或者单个进程运行时,显示描述信息和进度条
            if LOCAL_RANK in {-1, 0}:
                d = f"{desc} {nf} images, {nc} corrupt"
                TQDM(None, desc=d, total=n, initial=n)
                # 如果存在警告消息,则记录日志显示
                if cache["msgs"]:
                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
            # 返回样本列表
            return samples
        
        # 如果未能检索到缓存文件,则执行扫描操作
        nf, nc, msgs, samples, x = 0, 0, [], [], {}
        # 使用线程池并发执行图像验证函数
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
            # 创建进度条并显示扫描描述信息
            pbar = TQDM(results, desc=desc, total=len(self.samples))
            for sample, nf_f, nc_f, msg in pbar:
                # 如果图像未损坏,则将其添加到样本列表中
                if nf_f:
                    samples.append(sample)
                # 如果存在警告消息,则添加到消息列表中
                if msg:
                    msgs.append(msg)
                # 更新发现的和损坏的图像数量
                nf += nf_f
                nc += nc_f
                # 更新进度条的描述信息
                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
            # 关闭进度条
            pbar.close()
        
        # 如果存在警告消息,则记录日志显示
        if msgs:
            LOGGER.info("\n".join(msgs))
        
        # 计算数据集样本的哈希值并保存相关信息到 x 字典
        x["hash"] = get_hash([x[0] for x in self.samples])
        x["results"] = nf, nc, len(samples), samples
        x["msgs"] = msgs  # warnings
        
        # 将数据集缓存信息保存到缓存文件中
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        
        # 返回发现的样本列表
        return samples
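
For reference, the cache payload assembled above has roughly the following shape; the concrete values here are illustrative:

```python
# Sketch of the *.cache dict written by save_dataset_cache_file()
x = {
    "hash": "d41d8cd9...",  # hash of all image paths (illustrative value)
    "results": (1278, 2, 1280, samples),  # (found, corrupt, total, verified sample list)
    "msgs": ["WARNING ... corrupt JPEG restored"],  # warnings collected during the scan
    # save_dataset_cache_file() also stamps the version checked by the assert on reload
}
```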

.\yolov8\ultralytics\data\explorer\explorer.py

        data: Union[str, Path] = "coco128.yaml",
        model: str = "yolov8n.pt",
        uri: str = USER_CONFIG_DIR / "explorer",

The initializer accepts a data configuration file path or string (default "coco128.yaml"), a model file name (default "yolov8n.pt"), and a URI path (default the "explorer" directory under the user configuration directory).


        self.data = Path(data)
        self.model = Path(model)
        self.uri = Path(uri)

The data path, model path and URI are converted to `Path` objects and stored on the instance as `self.data`, `self.model` and `self.uri`.


        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

The device is chosen based on CUDA availability, preferring the GPU when one is present, and stored as `self.device`.


        self.model = YOLO(self.model).to(self.device).eval()

The specified YOLO weights are loaded with the `YOLO` class, moved to the selected device, and switched to evaluation mode (`eval`), overwriting the `Path` previously stored in `self.model`.


        self.data = ExplorerDataset(self.data)

The data configuration is wrapped in an `ExplorerDataset` and assigned back to `self.data` for the dataset exploration operations that follow.


    def embed_images(self, images: List[Union[np.ndarray, str, Path]]) -> List[np.ndarray]:
        """Embeds a list of images into feature vectors using the initialized YOLO model."""

Defines a method `embed_images` that takes a list of images (each an `np.ndarray`, a string path, or a `Path` object) and returns a list of feature vectors as `np.ndarray` objects.


        embeddings = []
        for image in tqdm(images, desc="Embedding images"):
            if isinstance(image, (str, Path)):
                image = cv2.imread(str(image))  # BGR
            if image is None:
                LOGGER.error(f"Image Not Found {image}")
                embeddings.append(None)
                continue
            if isinstance(image, np.ndarray):
                image = torch.from_numpy(image).to(self.device).float() / 255.0
            else:
                embeddings.append(None)
                continue
            if image.ndimension() == 3:
                image = image.unsqueeze(0)
            with torch.no_grad():
                features = self.model(image)[0].cpu().numpy()
            embeddings.append(features)
        return embeddings

For each image in the list: if it is a string path or `Path`, it is loaded with OpenCV (BGR order); if loading fails, an error is logged and `None` is appended to the embeddings; if the image is an `np.ndarray`, it is converted to a `torch.Tensor`, moved to the device, and normalized to [0, 1]; a batch dimension is added when needed; finally the YOLO model extracts the image features under `torch.no_grad()` and the feature vector is appended to the embeddings list.
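
A minimal usage sketch of the method above; the image path is an illustrative placeholder:

```python
explorer = Explorer()  # defaults as shown in the initializer above
vectors = explorer.embed_images(["datasets/coco128/images/train2017/000000000009.jpg"])
print(None if vectors[0] is None else vectors[0].shape)  # feature shape, or None if loading failed
```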


    def create_table(self, schema: dict) -> bool:
        """Creates a table in LanceDB using the provided schema."""

Defines a method `create_table` that takes a dictionary describing the table schema and returns a boolean indicating whether the table was created successfully.


        success = False
        try:
            success = get_table_schema(self.uri, schema)
        except Exception as e:
            LOGGER.error(f"Error creating table: {e}")
        return success

Attempts to create the table by calling `get_table_schema` with the URI path and the schema dictionary. On an exception, the error is logged and `False` is returned; otherwise the function's result is returned.


    def query_similarity(self, image: Union[np.ndarray, str, Path], threshold: float = 0.5) -> List[Tuple[str, float]]:
        """Queries LanceDB for images similar to the provided image, using YOLO features."""

Defines a method `query_similarity` that takes an image (an `np.ndarray`, a string path, or a `Path` object) and a similarity threshold, and returns a list of (file name, similarity score) tuples.


        schema = get_sim_index_schema()

Calls `get_sim_index_schema` to obtain the similarity index schema.


        results = []
        try:
            image_embed = self.embed_images([image])[0]

Calls the `embed_images` method to turn the provided image into a feature vector.


            if image_embed is None:
                return results

If the feature vector is `None`, the empty result list is returned immediately.


            query_result = prompt_sql_query(self.uri, schema, image_embed, threshold)

Runs the similarity query by calling `prompt_sql_query` with the URI path, schema, feature vector and threshold.


            results = [(r[0], float(r[1])) for r in query_result]
        except Exception as e:
            LOGGER.error(f"Error querying similarity: {e}")
        return results

Each query result is turned into a (file name, similarity score) tuple and collected into the result list. On an exception, the error is logged and the (empty) result list is returned.

    ) -> None:
        """Initialize the Explorer class, setting the dataset path, model, and database connection URI."""
        # Note: duckdb==0.10.0 has a bug, see https://github.com/ultralytics/ultralytics/pull/8181
        checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"])
        import lancedb

        # Open a connection to the database
        self.connection = lancedb.connect(uri)
        # Table name derived from the lowercase data path and model name
        self.table_name = f"{Path(data).name.lower()}_{model.lower()}"
        # Base name for the similarity index; the threshold and top_k are appended so tables can be reused
        self.sim_idx_base_name = (
            f"{self.table_name}_sim_idx".lower()
        )  # Use this name and append thres & top_k in order to reuse the table
        # Initialize the YOLO model
        self.model = YOLO(model)
        # Dataset path
        self.data = data  # None
        # No choice set selected yet
        self.choice_set = None

        # No table opened yet
        self.table = None
        # Progress starts at 0
        self.progress = 0
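
A minimal quick-start sketch for this class, assuming the default dataset YAML and weights can be resolved locally:

```python
exp = Explorer(data="coco128.yaml", model="yolov8n.pt")
exp.create_embeddings_table()
print(exp.table_name)  # "coco128.yaml_yolov8n.pt" per the naming scheme above
```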

    def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
        """
        Create a LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
        already exists. Pass force=True to overwrite the existing table.

        Args:
            force (bool): Whether to overwrite the existing table or not. Defaults to False.
            split (str): Split of the dataset to use. Defaults to 'train'.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            ```
        """
        # Return early if a table is already open and force is not set
        if self.table is not None and not force:
            LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
            return
        # Reuse an existing table of the same name unless force is set
        if self.table_name in self.connection.table_names() and not force:
            LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
            self.table = self.connection.open_table(self.table_name)
            self.progress = 1
            return
        # A dataset is required to build the embeddings table
        if self.data is None:
            raise ValueError("Data must be provided to create embeddings table")

        # Inspect the dataset details
        data_info = check_det_dataset(self.data)
        # The requested split must exist in the dataset
        if split not in data_info:
            raise ValueError(
                f"Split {split} not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
            )

        # Normalize the choice set to a list
        choice_set = data_info[split]
        choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
        self.choice_set = choice_set
        # Build an ExplorerDataset over the chosen split
        dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)

        # Probe one sample to derive the table schema
        batch = dataset[0]
        # Size of the embedding vector
        vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
        # Create the table
        table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
        # Stream the dataset batches into the table
        table.add(
            self._yield_batches(
                dataset,
                data_info,
                self.model,
                exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
            )
        )

        self.table = table

    def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
        """Generates batches of data for embedding, excluding specified keys."""
        # Iterate over every sample in the dataset
        for i in tqdm(range(len(dataset))):
            # Update the progress fraction
            self.progress = float(i + 1) / len(dataset)
            # Fetch the current sample
            batch = dataset[i]
            # Drop the excluded keys
            for k in exclude_keys:
                batch.pop(k, None)
            # Sanitize the batch against the dataset info
            batch = sanitize_batch(batch, data_info)
            # Embed the image file with the model
            batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
            # Yield the batch wrapped in a single-element list
            yield [batch]

    def query(
        self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
    ) -> Any:  # pyarrow.Table
        """
        Query the table for similar images. Accepts a single image or a list of images.

        Args:
            imgs (str or list): Path to the image or a list of paths to the images.
            limit (int): Number of results to return.

        Returns:
            (pyarrow.Table): An arrow table containing the results. Supports converting to:
                - pandas dataframe: `result.to_pandas()`
                - dict of lists: `result.to_pydict()`

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.query(imgs=['https://ultralytics.com/images/zidane.jpg'])
            ```
        """
        # The table must exist before querying
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")
        # Normalize a single path to a one-element list
        if isinstance(imgs, str):
            imgs = [imgs]
        # imgs must be a list at this point
        assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
        # Embed the images with the model
        embeds = self.model.embed(imgs)
        # Average the embeddings when multiple images are given
        embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
        # Run the vector search, capped at `limit` results
        return self.table.search(embeds).limit(limit).to_arrow()

    def sql_query(
        self, query: str, return_type: str = "pandas"
    ) -> Union[Any, None]:  # pandas.DataFrame or pyarrow.Table
        """
        Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.

        Args:
            query (str): SQL query to run.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            (pyarrow.Table): An arrow table containing the results.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
            result = exp.sql_query(query)
            ```
        """
        # Ensure the return_type is either 'pandas' or 'arrow'
        assert return_type in {
            "pandas",
            "arrow",
        }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
        
        import duckdb
        
        # Raise an error if the table is not created
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")

        # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
        # Convert the internal table representation to Arrow format
        table = self.table.to_arrow()  # noqa NOTE: Don't comment this. This line is used by DuckDB
        
        # Check if the query starts with correct SQL keywords
        if not query.startswith("SELECT") and not query.startswith("WHERE"):
            raise ValueError(
                f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE "
                f"clause. found {query}"
            )
        
        # If the query starts with WHERE, prepend it with SELECT * FROM 'table'
        if query.startswith("WHERE"):
            query = f"SELECT * FROM 'table' {query}"
        
        # Log the query being executed
        LOGGER.info(f"Running query: {query}")

        # Execute the SQL query using duckdb
        rs = duckdb.sql(query)
        
        # Return the result based on the specified return_type
        if return_type == "arrow":
            return rs.arrow()
        elif return_type == "pandas":
            return rs.df()
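
Because of the WHERE shorthand handled above, the two calls below are equivalent; this sketch assumes an `Explorer` instance `exp` whose embeddings table already exists:

```python
full = exp.sql_query("SELECT * FROM 'table' WHERE labels LIKE '%person%'")
short = exp.sql_query("WHERE labels LIKE '%person%'")  # expanded to the full query above
```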

    def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
        """
        Plot the results of a SQL-Like query on the table.
        
        Args:
            query (str): SQL query to run.
            labels (bool): Whether to plot the labels or not.

        Returns:
            (PIL.Image): Image containing the plot.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
            result = exp.plot_sql_query(query)
            ```
        """
        # Execute the SQL query with return_type='arrow' to get the result as an Arrow table
        result = self.sql_query(query, return_type="arrow")
        
        # If no results are found, log and return None
        if len(result) == 0:
            LOGGER.info("No results found.")
            return None
        
        # Generate a plot based on the query result and return it as a PIL Image
        img = plot_query_result(result, plot_labels=labels)
        return Image.fromarray(img)

    def get_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
        return_type: str = "pandas",
    ) -> Any:  # pandas.DataFrame or pyarrow.Table
        """
        Query the table for similar images. Accepts a single image or a list of images.

        Args:
            img (str or list): Path to the image or a list of paths to the images.
            idx (int or list): Index of the image in the table or a list of indexes.
            limit (int): Number of results to return. Defaults to 25.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            (pandas.DataFrame or pyarrow.Table): Depending on return_type, either a DataFrame or a Table.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        assert return_type in {"pandas", "arrow"}, f"Return type should be `pandas` or `arrow`, but got {return_type}"
        # Check if img argument is valid and normalize it
        img = self._check_imgs_or_idxs(img, idx)
        # Query for similar images using the normalized img argument
        similar = self.query(img, limit=limit)

        if return_type == "arrow":
            # Return the query result as a pyarrow.Table
            return similar
        elif return_type == "pandas":
            # Convert the query result to a pandas DataFrame and return
            return similar.to_pandas()

    def plot_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
        labels: bool = True,
    ) -> Image.Image:
        """
        Plot the similar images. Accepts images or indexes.

        Args:
            img (str or list): Path to the image or a list of paths to the images.
            idx (int or list): Index of the image in the table or a list of indexes.
            labels (bool): Whether to plot the labels or not.
            limit (int): Number of results to return. Defaults to 25.

        Returns:
            (PIL.Image): Image containing the plot.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        # Retrieve similar images data in arrow format
        similar = self.get_similar(img, idx, limit, return_type="arrow")
        # If no similar images found, log and return None
        if len(similar) == 0:
            LOGGER.info("No results found.")
            return None
        # Plot the query result and return as a PIL.Image
        img = plot_query_result(similar, plot_labels=labels)
        return Image.fromarray(img)

    def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any:  # pd.DataFrame
        """
        Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
        are max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when
                running vector search. Defaults to None.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image,
                and columns include indices of similar images and their respective distances.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            sim_idx = exp.similarity_index()
            ```
        """
        # The table must exist first
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")
        # Name the similarity index table after the max_dist and top_k parameters
        sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
        # Reuse an existing similarity index table unless force is set
        if sim_idx_table_name in self.connection.table_names() and not force:
            LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
            return self.connection.open_table(sim_idx_table_name).to_pandas()

        # top_k, when given, must lie between 0.0 and 1.0
        if top_k and not (1.0 >= top_k >= 0.0):
            raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
        # max_dist must be non-negative
        if max_dist < 0.0:
            raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")

        # Convert top_k from a fraction to a row count, with a floor of 1
        top_k = int(top_k * len(self.table)) if top_k else len(self.table)
        top_k = max(top_k, 1)
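        # Illustrative example (assumed numbers): with 500 rows in the table, top_k=0.1
        # keeps the 50 nearest neighbours per image; only neighbours whose L2 distance
        # is <= max_dist are counted toward the similarity index below.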
        # Pull the embedding vectors and image file names from the table
        features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
        im_files = features["im_file"]
        embeddings = features["vector"]

        # Create the similarity index table with the derived name and schema
        sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")

        def _yield_sim_idx():
            """Generates a dataframe with similarity indices and distances for images."""
            # Iterate over the embedding vectors with a progress bar
            for i in tqdm(range(len(embeddings))):
                # Search the table for the top_k nearest items, keeping those within max_dist
                sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
                # Yield one similarity-index record for this image
                yield [
                    {
                        "idx": i,
                        "im_file": im_files[i],
                        "count": len(sim_idx),
                        "sim_im_files": sim_idx["im_file"].tolist(),
                    }
                ]

        # Stream the similarity records into the index table
        sim_table.add(_yield_sim_idx())
        # Keep a handle to the similarity index on the instance
        self.sim_index = sim_table
        # Return the similarity index as a pandas dataframe
        return sim_table.to_pandas()

    def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
        """
        Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
        max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when
                running vector search. Defaults to None.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            (PIL.Image): Image containing the plot.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()

            similarity_idx_plot = exp.plot_similarity_index()
            similarity_idx_plot.show() # view image preview
            similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
            ```
        """
        # Retrieve similarity index based on provided parameters
        sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
        
        # Extract counts of similar images from the similarity index
        sim_count = sim_idx["count"].tolist()
        sim_count = np.array(sim_count)

        # Generate indices for the bar plot
        indices = np.arange(len(sim_count))

        # Create the bar plot using matplotlib
        plt.bar(indices, sim_count)

        # Customize the plot with labels and title
        plt.xlabel("data idx")
        plt.ylabel("Count")
        plt.title("Similarity Count")
        
        # Save the plot to a PNG image in memory
        buffer = BytesIO()
        plt.savefig(buffer, format="png")
        buffer.seek(0)

        # Use Pillow to open the image from the buffer and return it
        return Image.fromarray(np.array(Image.open(buffer)))


    def _check_imgs_or_idxs(
        self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
    ) -> List[np.ndarray]:
        """Determines whether to fetch images or indexes based on provided arguments and returns image paths."""
        # Check if both img and idx are None, which is not allowed
        if img is None and idx is None:
            raise ValueError("Either img or idx must be provided.")
        
        # Check if both img and idx are provided, which is also not allowed
        if img is not None and idx is not None:
            raise ValueError("Only one of img or idx must be provided.")
        
        # If idx is provided, fetch corresponding image paths from the table
        if idx is not None:
            idx = idx if isinstance(idx, list) else [idx]
            img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]

        # Return the image(s) wrapped in a list
        return img if isinstance(img, list) else [img]
    # Ask the AI a question and return the filtered results
    def ask_ai(self, query):
        """
        Ask AI a question.

        Args:
            query (str): Question to ask.

        Returns:
            (pandas.DataFrame): A dataframe containing filtered results to the SQL query.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            answer = exp.ask_ai('Show images with 1 person and 2 dogs')
            ```
        """
        # Turn the natural-language question into an SQL query
        result = prompt_sql_query(query)
        try:
            # Run the generated SQL and return the resulting dataframe
            return self.sql_query(result)
        except Exception as e:
            # Log the error and return None when the generated SQL is invalid
            LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
            LOGGER.error(e)
            return None

    # Visualize query results; currently not implemented
    def visualize(self, result):
        """
        Visualize the results of a query. TODO.

        Args:
            result (pyarrow.Table): Table containing the results of a query.
        """
        pass  # placeholder, not yet implemented

    # Generate a dataset report; currently not implemented
    def generate_report(self, result):
        """
        Generate a report of the dataset.

        TODO
        """
        pass  # placeholder, not yet implemented