Yolov8-源码解析-二十九-

Yolov8 源码解析(二十九)

.\yolov8\ultralytics\data\explorer\gui\dash.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import sys
import time
from threading import Thread

from ultralytics import Explorer  # 导入Explorer类,用于数据探索
from ultralytics.utils import ROOT, SETTINGS  # 导入项目根目录和设置模块
from ultralytics.utils.checks import check_requirements  # 导入检查依赖的函数

check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.3"))  # 检查必要的Streamlit版本和扩展

import streamlit as st  # 导入Streamlit库
from streamlit_select import image_select  # 导入图像选择组件


def _get_explorer():
    """
    Build an Explorer from the submitted form values and create its embeddings table.

    Reads "dataset", "model", "force_recreate_embeddings" and "split" from ``st.session_state``, runs
    ``exp.create_embeddings_table`` in a background thread and mirrors ``exp.progress`` in a Streamlit progress
    bar. Once the worker has finished, the Explorer is stored under ``st.session_state["explorer"]``.
    """
    exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))

    # Run the (potentially slow) table creation off the main thread so the progress bar can be refreshed
    thread = Thread(
        target=exp.create_embeddings_table,
        kwargs={"force": st.session_state.get("force_recreate_embeddings"), "split": st.session_state.get("split")},
    )
    thread.start()

    progress_bar = st.progress(0, text="Creating embeddings table...")
    # Stop polling as soon as the worker exits: if it fails (or returns) before progress reaches 1,
    # waiting on the progress value alone would never terminate.
    while thread.is_alive() and exp.progress < 1:
        time.sleep(0.1)
        progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
    thread.join()
    st.session_state["explorer"] = exp
    progress_bar.empty()


def init_explorer_form(data=None, model=None):
    """
    Render the form used to pick the dataset, model and split for a new Explorer.

    Args:
        data: Dataset name offered as the only choice. When None, the name of every ``*.yaml`` file found in
            ``ROOT / "cfg" / "datasets"`` is offered.
        model: Model name offered as the only choice. When None, the YOLOv8 detect, segment and pose
            checkpoints (sizes n/s/m/l/x) are offered.
    """
    if data is not None:
        dataset_choices = [data]
    else:
        dataset_choices = [cfg.name for cfg in (ROOT / "cfg" / "datasets").glob("*.yaml")]

    if model is not None:
        model_choices = [model]
    else:
        # Ordered as detect (n, s, m, l, x), then segment, then pose
        model_choices = [f"yolov8{size}{task}.pt" for task in ("", "-seg", "-pose") for size in "nsmlx"]

    split_choices = ["train", "val", "test"]

    with st.form(key="explorer_init_form"):
        dataset_col, model_col, split_col = st.columns(3)
        with dataset_col:
            st.selectbox("Select dataset", dataset_choices, key="dataset")
        with model_col:
            st.selectbox("Select model", model_choices, key="model")
        with split_col:
            st.selectbox("Select split", split_choices, key="split")
        st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")
        # Submitting the form triggers creation of the Explorer and its embeddings table
        st.form_submit_button("Explore", on_click=_get_explorer)


def query_form():
    """Render the SQL query form: a text box keyed "query" plus a button that runs `run_sql_query`."""
    with st.form("query_form"):
        input_col, button_col = st.columns([0.8, 0.2])
        with input_col:
            # The label is collapsed, so only the pre-filled example query is visible
            st.text_input(
                label="Query",
                value="WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
                label_visibility="collapsed",
                key="query",
            )
        with button_col:
            st.form_submit_button("Query", on_click=run_sql_query)


def ai_query_form():
    """Render the natural-language query form: a text box keyed "ai_query" plus a button that runs `run_ai_query`."""
    with st.form("ai_query_form"):
        # Wide column for the prompt, narrow column for the submit button
        col1, col2 = st.columns([0.8, 0.2])
        with col1:
            # The label is collapsed, so only the pre-filled example prompt is visible
            st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
        with col2:
            st.form_submit_button("Ask AI", on_click=run_ai_query)
def find_similar_imgs(imgs):
    """
    Run a similarity search for the given images and publish the matches to the session state.

    Args:
        imgs: Image(s) forwarded as the ``img`` argument of the Explorer's ``get_similar``.

    The number of results is taken from ``st.session_state["limit"]``. The matching file paths are stored under
    "imgs" and the full Arrow result under "res".
    """
    explorer = st.session_state["explorer"]
    matches = explorer.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
    st.session_state["imgs"] = matches.to_pydict()["im_file"]
    st.session_state["res"] = matches


def similarity_form(selected_imgs):
    """
    Render the similarity-search form for the currently selected images.

    Args:
        selected_imgs: Images picked in the gallery; they are passed to `find_similar_imgs` on submit.
    """
    st.write("Similarity Search")
    n_selected = len(selected_imgs)
    nothing_selected = n_selected == 0  # searching makes no sense without a query image

    with st.form("similarity_form"):
        limit_col, search_col = st.columns([1, 1])

        with limit_col:
            # Maximum number of similar images to fetch (stored under the "limit" key)
            st.number_input(
                "limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
            )

        with search_col:
            st.write("Selected: ", n_selected)
            st.form_submit_button(
                "Search",
                disabled=nothing_selected,
                on_click=find_similar_imgs,
                args=(selected_imgs,),
            )

        if nothing_selected:
            st.error("Select at least one image to search.")


# Commented-out code (disabled in the original source; kept for reference)
# def persist_reset_form():
#    with st.form("persist_reset"):
#        col1, col2 = st.columns([1, 1])
#        with col1:
#            st.form_submit_button("Reset", on_click=reset)
#
#        with col2:
#            st.form_submit_button("Persist", on_click=update_state, args=("PERSISTING", True))


def run_sql_query():
    """
    Execute the SQL query typed into the query form and publish the result to the session state.

    Clears any previous error, then, unless the query is blank, stores the matching file paths under "imgs"
    and the full Arrow result under "res".
    """
    st.session_state["error"] = None
    sql = st.session_state.get("query")
    if not sql.strip():
        return  # nothing to run for a blank query

    explorer = st.session_state["explorer"]
    result = explorer.sql_query(sql, return_type="arrow")
    st.session_state["imgs"] = result.to_pydict()["im_file"]
    st.session_state["res"] = result


def run_ai_query():
    """
    Answer the natural-language query from the AI form and publish the result to the session state.

    Sets ``st.session_state["error"]`` when no OpenAI API key is configured or when the AI-generated query
    yields no rows; otherwise stores the matching file paths under "imgs" and the DataFrame under "res".
    """
    if not SETTINGS["openai_api_key"]:
        st.session_state["error"] = (
            'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
        )
        return
    import pandas  # scope for faster 'import ultralytics'

    st.session_state["error"] = None
    prompt = st.session_state.get("ai_query")
    if not prompt.strip():
        return  # nothing to ask for a blank prompt

    result = st.session_state["explorer"].ask_ai(prompt)
    if isinstance(result, pandas.DataFrame) and not result.empty:
        st.session_state["imgs"] = result["im_file"].to_list()
        st.session_state["res"] = result
    else:
        st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."


def reset_explorer():
    """Reset the explorer to its initial state by clearing the "explorer", "imgs" and "error" session entries."""
    for key in ("explorer", "imgs", "error"):
        st.session_state[key] = None


def utralytics_explorer_docs_callback():
    """Render a bordered box with the Ultralytics logo, a short blurb and a link to the Explorer API docs."""
    with st.container(border=True):
        # Logo, 100 px wide, loaded from the Ultralytics assets repository
        st.image(
            "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
            width=100,
        )
        # The blurb contains an HTML link, hence unsafe_allow_html=True
        st.markdown(
            "<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
            unsafe_allow_html=True,
            help=None,
        )
        st.link_button("Ultralytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
def layout(data=None, model=None):
    """
    Render the Explorer demo page.

    While no Explorer is stored in the session state, only the initialisation form is shown. Afterwards the page
    shows the image gallery with paging controls, the SQL and AI query forms, the similarity-search form, the
    "Labels" toggle and the documentation box.

    Args:
        data: Optional dataset name forwarded to `init_explorer_form`.
        model: Optional model name forwarded to `init_explorer_form`.
    """
    st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
    st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)

    # No explorer yet: show the setup form only
    if st.session_state.get("explorer") is None:
        init_explorer_form(data, model)
        return

    st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
    exp = st.session_state.get("explorer")
    col1, col2 = st.columns([0.75, 0.25], gap="small")

    # Decide what to show: an error, the last query result, or the whole table
    imgs = []
    if st.session_state.get("error"):
        st.error(st.session_state["error"])
    elif st.session_state.get("imgs"):
        imgs = st.session_state.get("imgs")
    else:
        imgs = exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
        st.session_state["res"] = exp.table.to_arrow()

    total_imgs, selected_imgs = len(imgs), []
    with col1:
        subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)

        with subcol1:
            st.write("Max Images Displayed:")

        with subcol2:
            num = st.number_input(
                "Max Images Displayed",
                min_value=0,
                max_value=total_imgs,
                value=min(500, total_imgs),
                key="num_imgs_displayed",
                label_visibility="collapsed",
            )

        with subcol3:
            st.write("Start Index:")

        with subcol4:
            start_idx = st.number_input(
                "Start Index",
                min_value=0,
                max_value=total_imgs,
                value=0,
                key="start_index",
                label_visibility="collapsed",
            )

        with subcol5:
            reset = st.button("Reset", use_container_width=False, key="reset")
            if reset:
                # Drop the current query result and redraw so the full table is shown again
                st.session_state["imgs"] = None
                # st.rerun() replaces the deprecated st.experimental_rerun(); it is available in the
                # streamlit>=1.29.0 required at the top of this module.
                st.rerun()

        query_form()
        ai_query_form()

        if total_imgs:
            labels, boxes, masks, kpts, classes = None, None, None, None, None
            task = exp.model.task

            if st.session_state.get("display_labels"):
                res = st.session_state.get("res")
                # "res" is an Arrow table after SQL/similarity queries but a pandas DataFrame after an AI
                # query, and a DataFrame has no to_pydict(). Convert once instead of once per column.
                columns = res.to_pydict() if hasattr(res, "to_pydict") else res.to_dict(orient="list")
                window = slice(start_idx, start_idx + num)
                labels = columns["labels"][window]
                boxes = columns["bboxes"][window]
                masks = columns["masks"][window]
                kpts = columns["keypoints"][window]
                classes = columns["cls"][window]

            imgs_displayed = imgs[start_idx : start_idx + num]

            # Masks/keypoints are only passed on for the matching task
            selected_imgs = image_select(
                f"Total samples: {total_imgs}",
                images=imgs_displayed,
                use_container_width=False,
                labels=labels,
                classes=classes,
                bboxes=boxes,
                masks=masks if task == "segment" else None,
                kpts=kpts if task == "pose" else None,
            )

    with col2:
        similarity_form(selected_imgs)
        st.checkbox("Labels", value=False, key="display_labels")
        utralytics_explorer_docs_callback()
if __name__ == "__main__":
    # Command-line arguments come as alternating "key value" tokens; pair them up and forward
    # them to layout() as keyword arguments.
    cli_args = sys.argv[1:]
    layout(**dict(zip(cli_args[::2], cli_args[1::2])))

.\yolov8\ultralytics\data\explorer\gui\__init__.py

# Ultralytics YOLO 🚀, AGPL-3.0 license
# (This package __init__ contains nothing but the license header above: the project name and the
# open-source license, AGPL-3.0, under which the code is distributed.)

.\yolov8\ultralytics\data\explorer\utils.py

# 导入必要的模块和库
import getpass  # 导入获取用户信息的模块
from typing import List  # 引入列表类型提示

import cv2  # 导入OpenCV库,用于图像处理
import numpy as np  # 导入NumPy库,用于数值计算

# 导入Ultralytics项目中的数据增强、日志等工具
from ultralytics.data.augment import LetterBox
from ultralytics.utils import LOGGER as logger
from ultralytics.utils import SETTINGS
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.ops import xyxy2xywh
from ultralytics.utils.plotting import plot_images


def get_table_schema(vector_size):
    """Return the LanceModel schema of the embeddings table, with a vector column of size `vector_size`."""
    # Imported lazily so lancedb is only needed when this function is called
    from lancedb.pydantic import LanceModel, Vector

    # One row per image
    class Schema(LanceModel):
        im_file: str  # image file path
        labels: List[str]  # label names
        cls: List[int]  # class indices
        bboxes: List[List[float]]  # bounding boxes, one list of floats per box
        masks: List[List[List[int]]]  # masks
        keypoints: List[List[List[float]]]  # keypoints
        vector: Vector(vector_size)  # embedding vector of length vector_size

    return Schema


def get_sim_index_schema():
    """Return the LanceModel schema of the similarity-index table."""
    # Imported lazily so lancedb is only needed when this function is called
    from lancedb.pydantic import LanceModel

    # One row per image
    class Schema(LanceModel):
        idx: int  # index
        im_file: str  # image file path
        count: int  # count (presumably the number of similar images; confirm against the caller)
        sim_im_files: List[str]  # file paths of the similar images

    return Schema


def sanitize_batch(batch, dataset_info):
    """
    Convert a batch to plain Python lists, ordered by class, ready to be written to the table.

    Args:
        batch (dict): Must hold "cls" and "bboxes" objects exposing ``tolist()``; may hold "masks" and "keypoints".
        dataset_info (dict): Its "names" entry maps a class index to a label name.

    Returns:
        (dict): The same `batch`, modified in place: "cls" and "bboxes" sorted by class index, "labels" added,
            and "masks"/"keypoints" converted to lists (``[[[]]]`` when absent).
    """
    batch["cls"] = batch["cls"].flatten().int().tolist()

    # Stable sort of (box, class) pairs by class index keeps each box attached to its class
    pairs = list(zip(batch["bboxes"].tolist(), batch["cls"]))
    pairs.sort(key=lambda pair: pair[1])
    batch["bboxes"] = [pair[0] for pair in pairs]
    batch["cls"] = [pair[1] for pair in pairs]

    names = dataset_info["names"]
    batch["labels"] = [names[class_id] for class_id in batch["cls"]]

    # Optional annotation types fall back to an empty nested-list placeholder
    for key in ("masks", "keypoints"):
        batch[key] = batch[key].tolist() if key in batch else [[[]]]
    return batch


def plot_query_result(similar_set, plot_labels=True):
    """
    Plot the images of a query result, optionally with their labels, via `plot_images`.

    Args:
        similar_set (pyarrow.Table | pandas.DataFrame): Query result with an "im_file" column and, optionally,
            "bboxes", "masks", "keypoints" and "cls" columns.
        plot_labels (bool): Whether to pass boxes, masks and keypoints on for drawing.

    Returns:
        The value returned by `plot_images` (called with ``save=False``).

    Raises:
        FileNotFoundError: If an image listed in "im_file" cannot be read.
    """
    import pandas  # scope for faster 'import ultralytics'

    # Normalise both supported inputs to a dict of column name -> list of per-image values
    similar_set = (
        similar_set.to_dict(orient="list") if isinstance(similar_set, pandas.DataFrame) else similar_set.to_pydict()
    )
    empty_masks = [[[]]]  # placeholder stored for images without masks/keypoints

    images = similar_set.get("im_file", [])
    # The original compared with `is not` against a fresh list, which is always True; per-image
    # emptiness is checked inside the loop instead.
    bboxes = similar_set.get("bboxes") or []
    # Masks/keypoints are treated as absent when the column is missing, empty, or its first row is the
    # empty placeholder (a missing or empty column previously raised).
    raw_masks = similar_set.get("masks")
    masks = raw_masks if raw_masks and raw_masks[0] != empty_masks else []
    raw_kpts = similar_set.get("keypoints")
    kpts = raw_kpts if raw_kpts and raw_kpts[0] != empty_masks else []
    cls = similar_set.get("cls", [])

    plot_size = 640
    imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
    for i, imf in enumerate(images):
        im = cv2.imread(imf)
        if im is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Image not found or unreadable: {imf}")
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        h, w = im.shape[:2]
        r = min(plot_size / h, plot_size / w)  # scale applied to box/keypoint coordinates
        imgs.append(LetterBox(plot_size, center=False)(image=im).transpose(2, 0, 1))  # HWC -> CHW

        img_boxes = bboxes[i] if i < len(bboxes) else []  # tolerate a missing or short "bboxes" column
        if plot_labels:
            if len(img_boxes) > 0:
                box = np.array(img_boxes, dtype=np.float32)
                box[:, [0, 2]] *= r
                box[:, [1, 3]] *= r
                plot_boxes.append(box)

            if len(masks) > i and len(masks[i]) > 0:
                mask = np.array(masks[i], dtype=np.uint8)[0]
                plot_masks.append(LetterBox(plot_size, center=False)(image=mask))

            if len(kpts) > i and kpts[i] is not None:
                kpt = np.array(kpts[i], dtype=np.float32)
                kpt[:, :, :2] *= r
                plot_kpts.append(kpt)

        # One batch-index entry per box of image i
        batch_idx.append(np.ones(len(img_boxes)) * i)

    imgs = np.stack(imgs, axis=0)
    masks = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8)
    kpts = np.concatenate(plot_kpts, axis=0) if plot_kpts else np.zeros((0, 51), dtype=np.float32)
    boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if plot_boxes else np.zeros(0, dtype=np.float32)
    batch_idx = np.concatenate(batch_idx, axis=0)
    # np.concatenate rejects an empty sequence, so fall back to an empty array when "cls" is missing
    cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0) if cls else np.zeros(0, dtype=np.int32)

    return plot_images(
        imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
    )
def prompt_sql_query(query):
    """
    Ask an OpenAI chat model to turn a natural-language request into a SQL query over the table schema.

    Args:
        query (str): The user's request, sent as the user message.

    Returns:
        The message content of the first completion choice (expected to be the generated SQL query).
    """
    check_requirements("openai>=1.6.1")
    from openai import OpenAI

    # Without a stored key, ask for one interactively and keep it in SETTINGS
    if not SETTINGS["openai_api_key"]:
        logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
        SETTINGS.update({"openai_api_key": getpass.getpass("OpenAI API key: ")})
    client = OpenAI(api_key=SETTINGS["openai_api_key"])

    # System prompt: describes the table schema and shows one worked example
    system_prompt = """
                You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
                the following schema and a user request. You only need to output the format with fixed selection
                statement that selects everything from "'table'", like `SELECT * from 'table'`

                Schema:
                im_file: string not null
                labels: list<item: string> not null
                child 0, item: string
                cls: list<item: int64> not null
                child 0, item: int64
                bboxes: list<item: list<item: double>> not null
                child 0, item: list<item: double>
                    child 0, item: double
                masks: list<item: list<item: list<item: int64>> not null
                child 0, item: list<item: list<item: int64>>
                    child 0, item: list<item: int64>
                        child 0, item: int64
                keypoints: list<item: list<item: list<item: double>> not null
                child 0, item: list<item: list<item: double>>
                    child 0, item: list<item: double>
                        child 0, item: double
                vector: fixed_size_list<item: float>[256] not null
                child 0, item: float

                Some details about the schema:
                - the "labels" column contains the string values like 'person' and 'dog' for the respective objects
                    in each image
                - the "cls" column contains the integer values on these classes that map them the labels

                Example of a correct query:
                request - Get all data points that contain 2 or more people and at least one dog
                correct query-
                SELECT * FROM 'table' WHERE  ARRAY_LENGTH(cls) >= 2  AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2  AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
             """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{query}"},
    ]

    completion = client.chat.completions.create(model="gpt-3.5-turbo", messages=conversation)
    return completion.choices[0].message.content

.\yolov8\ultralytics\data\explorer\__init__.py

# Re-export plot_query_result from the sibling utils module
from .utils import plot_query_result

# Public API of this package: only plot_query_result is exported
__all__ = ["plot_query_result"]

.\yolov8\ultralytics\data\loaders.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import glob  # 导入glob模块,用于获取文件路径列表
import math  # 导入math模块,提供数学计算函数
import os  # 导入os模块,用于与操作系统进行交互
import time  # 导入time模块,提供时间相关函数
from dataclasses import dataclass  # 导入dataclass类,用于创建数据类
from pathlib import Path  # 导入Path类,用于处理路径
from threading import Thread  # 导入Thread类,用于实现多线程操作
from urllib.parse import urlparse  # 导入urlparse函数,用于解析URL

import cv2  # 导入cv2模块,OpenCV库
import numpy as np  # 导入numpy库,用于数值计算
import requests  # 导入requests模块,用于HTTP请求
import torch  # 导入torch模块,PyTorch深度学习库
from PIL import Image  # 导入Image类,Python图像处理库PIL的一部分

from ultralytics.data.utils import FORMATS_HELP_MSG, IMG_FORMATS, VID_FORMATS  # 导入自定义模块的特定内容
from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, ops  # 导入自定义模块的特定内容
from ultralytics.utils.checks import check_requirements  # 导入自定义模块的特定函数


@dataclass
class SourceTypes:
    """Class to represent various types of input sources for predictions."""
    
    stream: bool = False  # source is a stream; defaults to False
    screenshot: bool = False  # source is a screenshot; defaults to False
    from_img: bool = False  # source is an image; defaults to False
    tensor: bool = False  # source is a tensor; defaults to False


class LoadStreams:
    """
    Stream Loader for various types of video streams, Supports RTSP, RTMP, HTTP, and TCP streams.

    Attributes:
        sources (str): The source input paths or URLs for the video streams.
        vid_stride (int): Video frame-rate stride, defaults to 1.
        buffer (bool): Whether to buffer input streams, defaults to False.
        running (bool): Flag to indicate if the streaming thread is running.
        mode (str): Set to 'stream' indicating real-time capture.
        imgs (list): List of image frames for each stream.
        fps (list): List of FPS for each stream.
        frames (list): List of total frames for each stream.
        threads (list): List of threads for each stream.
        shape (list): List of shapes for each stream.
        caps (list): List of cv2.VideoCapture objects for each stream.
        bs (int): Batch size for processing.

    Methods:
        __init__: Initialize the stream loader.
        update: Read stream frames in daemon thread.
        close: Close stream loader and release resources.
        __iter__: Returns an iterator object for the class.
        __next__: Returns source paths, transformed, and original images for processing.
        __len__: Return the length of the sources object.

    Example:
         ```bash
         yolo predict source='rtsp://example.com/media.mp4'
         ```
    """

    def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
        """
        Open every source, read its first frame and start one daemon reader thread per stream.

        Args:
            sources (str): Path to a text file listing one source per whitespace-separated token, or a single
                source (URL, path, or a numeric string such as '0' for a local webcam).
            vid_stride (int): Keep every `vid_stride`-th frame.
            buffer (bool): If True, queue frames (up to 30 per stream); otherwise keep only the latest frame.

        Raises:
            NotImplementedError: If webcam source 0 is requested in Colab or Kaggle.
            ConnectionError: If a source cannot be opened or its first frame cannot be read.
        """
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.buffer = buffer  # buffer input streams
        self.running = True  # running flag for Thread
        self.mode = "stream"
        self.vid_stride = vid_stride  # video frame-rate stride

        # Read the source list from a file if `sources` is a file, otherwise treat it as a single source
        sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
        n = len(sources)
        self.bs = n  # one batch item per source
        self.fps = [0] * n  # frames per second of each stream
        self.frames = [0] * n  # total frame count of each stream
        self.threads = [None] * n  # reader thread of each stream
        self.caps = [None] * n  # cv2.VideoCapture object of each stream
        self.imgs = [[] for _ in range(n)]  # frame buffer of each stream
        self.shape = [[] for _ in range(n)]  # frame shape of each stream
        self.sources = [ops.clean_str(x) for x in sources]  # cleaned source names for later use

        for i, s in enumerate(sources):
            st = f"{i + 1}/{n}: {s}... "  # log/error message prefix

            # Resolve YouTube page URLs to a direct stream URL
            if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}:
                s = get_best_youtube_url(s)

            # A purely decimal source (e.g. '0') is a local camera index. int() replaces the original
            # eval() so that user-supplied text is never evaluated as code.
            s = int(s) if s.isdecimal() else s

            if s == 0 and (IS_COLAB or IS_KAGGLE):
                raise NotImplementedError(
                    "'source=0' webcam not supported in Colab and Kaggle notebooks. "
                    "Try running 'source=0' in a local environment."
                )

            self.caps[i] = cv2.VideoCapture(s)
            if not self.caps[i].isOpened():
                raise ConnectionError(f"{st}Failed to open {s}")

            w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = self.caps[i].get(cv2.CAP_PROP_FPS)

            # A reported frame count of 0 (or less) means "unknown": treat the stream as endless
            self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float("inf")

            # Fall back to 30 FPS when the reported value is non-finite or reduces to 0
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30

            # Read the first frame to make sure the stream actually delivers images
            success, im = self.caps[i].read()
            if not success or im is None:
                raise ConnectionError(f"{st}Failed to read images from {s}")

            self.imgs[i].append(im)
            self.shape[i] = im.shape

            # Daemon thread that keeps reading frames from this stream
            self.threads[i] = Thread(target=self.update, args=(i, self.caps[i], s), daemon=True)
            LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()

        LOGGER.info("")  # newline

    def update(self, i, cap, stream):
        """
        Read frames of stream `i` in a daemon thread.

        Args:
            i (int): Index of the stream.
            cap (cv2.VideoCapture): Capture object of the stream.
            stream: Source passed to `cap.open` to reconnect when a frame cannot be retrieved.
        """
        n, f = 0, self.frames[i]  # frames grabbed so far, total frame count
        while self.running and cap.isOpened() and n < (f - 1):
            if len(self.imgs[i]) < 30:  # keep at most 30 frames buffered
                n += 1
                cap.grab()  # grab every frame, but only decode the ones that are kept
                if n % self.vid_stride == 0:
                    success, im = cap.retrieve()
                    if not success:
                        im = np.zeros(self.shape[i], dtype=np.uint8)  # black placeholder frame
                        LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.")
                        cap.open(stream)  # re-open stream if signal was lost
                    if self.buffer:
                        self.imgs[i].append(im)
                    else:
                        self.imgs[i] = [im]  # keep only the latest frame
            else:
                time.sleep(0.01)  # buffer is full: wait until the consumer has taken frames

    def close(self):
        """Stop the reader threads and release all capture objects and OpenCV windows."""
        self.running = False  # signals the reader threads to stop
        for thread in self.threads:
            # Slots are None for streams whose setup never completed (e.g. __init__ raised part-way)
            if thread is not None and thread.is_alive():
                thread.join(timeout=5)
        for cap in self.caps:
            if cap is None:
                continue
            try:
                cap.release()
            except Exception as e:  # best effort: releasing must not abort the shutdown
                LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}")
        cv2.destroyAllWindows()

    def __iter__(self):
        """Reset the iteration counter and return the loader itself as the iterator."""
        self.count = -1
        return self

    def __next__(self):
        """
        Return the next batch: source names, one frame per stream, and a list of empty strings.

        Raises:
            StopIteration: If a stream's buffer is empty and its reader thread has ended, or 'q' is pressed.
        """
        self.count += 1

        images = []
        for i, x in enumerate(self.imgs):
            # Wait until a frame is available for this stream
            while not x:
                if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"):
                    self.close()
                    raise StopIteration
                time.sleep(1 / min(self.fps))
                x = self.imgs[i]  # re-read: in non-buffer mode the reader replaces the list object
                if not x:
                    LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}")

            if self.buffer:
                # Buffered mode: consume frames in arrival order
                images.append(x.pop(0))
            else:
                # Unbuffered mode: take the newest frame and drop the rest
                images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
                x.clear()

        return self.sources, images, [""] * self.bs

    def __len__(self):
        """Return the number of sources, which is also the batch size."""
        return self.bs
class LoadScreenshots:
    """
    Screen-capture dataloader for YOLOv8.

    Grabs screenshots of a monitor, or of a rectangular region of it, for processing with YOLOv8.
    Suitable for use with `yolo predict source=screen`.

    Attributes:
        screen (int): Index of the monitor to capture.
        left (int): Left coordinate of the capture area.
        top (int): Top coordinate of the capture area.
        width (int): Width of the capture area.
        height (int): Height of the capture area.
        mode (str): Always 'stream'.
        frame (int): Number of screenshots captured so far.
        sct (mss.mss): Screen capture object from the `mss` library.
        bs (int): Batch size, always 1.
        fps (int): Frame rate value, always 30.
        monitor (dict): Capture rectangle passed to `sct.grab`.

    Methods:
        __iter__: Returns the loader itself.
        __next__: Captures the next screenshot and returns it.
    """

    def __init__(self, source):
        """Source = [screen_number left top width height] (pixels)."""
        check_requirements("mss")  # make sure the screen-capture dependency is installed
        import mss  # noqa

        # The first token is the source keyword; the remaining tokens select screen and region.
        _keyword, *params = source.split()
        self.screen, left, top, width, height = 0, None, None, None, None  # default: full screen 0
        n_params = len(params)
        if n_params == 1:
            self.screen = int(params[0])
        elif n_params == 4:
            left, top, width, height = map(int, params)
        elif n_params == 5:
            self.screen, left, top, width, height = map(int, params)

        self.mode = "stream"
        self.frame = 0
        self.sct = mss.mss()
        self.bs = 1
        self.fps = 30

        # Offsets are relative to the selected monitor; missing values fall back to the monitor geometry.
        monitor = self.sct.monitors[self.screen]
        if top is None:
            self.top = monitor["top"]
        else:
            self.top = monitor["top"] + top
        if left is None:
            self.left = monitor["left"]
        else:
            self.left = monitor["left"] + left
        self.width = width if width else monitor["width"]
        self.height = height if height else monitor["height"]
        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}

    def __iter__(self):
        """Return the loader itself as the iterator."""
        return self

    def __next__(self):
        """Grab the configured screen region with mss and return it as a numpy array."""
        shot = self.sct.grab(self.monitor)
        im0 = np.asarray(shot)[:, :, :3]  # BGRA to BGR
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "

        self.frame += 1
        return [str(self.screen)], [im0], [s]  # screen, img, string
    # 定义一个数据加载器类,用于加载图像和视频文件
    class Dataloader:
        """
        Attributes:
            files (list): List of image and video file paths.
            nf (int): Total number of files (images and videos).
            video_flag (list): Flags indicating whether a file is a video (True) or an image (False).
            mode (str): Current mode, 'image' or 'video'.
            vid_stride (int): Stride for video frame-rate, defaults to 1.
            bs (int): Batch size, set to 1 for this class.
            cap (cv2.VideoCapture): Video capture object for OpenCV.
            frame (int): Frame counter for video.
            frames (int): Total number of frames in the video.
            count (int): Counter for iteration, initialized at 0 during `__iter__()`.

        Methods:
            _new_video(path): Create a new cv2.VideoCapture object for a given video path.
        """

        def __init__(self, path, batch=1, vid_stride=1):
            """Initialize the Dataloader and raise FileNotFoundError if file not found."""
            parent = None
            if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
                parent = Path(path).parent
                path = Path(path).read_text().splitlines()  # list of sources
            files = []
            for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
                a = str(Path(p).absolute())  # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
                if "*" in a:
                    files.extend(sorted(glob.glob(a, recursive=True)))  # glob
                elif os.path.isdir(a):
                    files.extend(sorted(glob.glob(os.path.join(a, "*.*"))))  # dir
                elif os.path.isfile(a):
                    files.append(a)  # files (absolute or relative to CWD)
                elif parent and (parent / p).is_file():
                    files.append(str((parent / p).absolute()))  # files (relative to *.txt file parent)
                else:
                    raise FileNotFoundError(f"{p} does not exist")

            # Define files as images or videos
            images, videos = [], []
            for f in files:
                suffix = f.split(".")[-1].lower()  # Get file extension without the dot and lowercase
                if suffix in IMG_FORMATS:
                    images.append(f)
                elif suffix in VID_FORMATS:
                    videos.append(f)
            ni, nv = len(images), len(videos)

            self.files = images + videos
            self.nf = ni + nv  # number of files
            self.ni = ni  # number of images
            self.video_flag = [False] * ni + [True] * nv
            self.mode = "image"
            self.vid_stride = vid_stride  # video frame-rate stride
            self.bs = batch
            if any(videos):
                self._new_video(videos[0])  # new video
            else:
                self.cap = None
            if self.nf == 0:
                raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")

        def __iter__(self):
            """Returns an iterator object for VideoStream or ImageFolder."""
            self.count = 0
            return self
    def __next__(self):
        """Returns the next batch of images or video frames along with their paths and metadata."""
        paths, imgs, info = [], [], []  # 初始化空列表,用于存储路径、图像/视频帧和元数据信息
        while len(imgs) < self.bs:  # 当图像/视频帧列表长度小于批次大小时执行循环
            if self.count >= self.nf:  # 如果计数器超过文件总数,则表示文件列表结束
                if imgs:
                    return paths, imgs, info  # 返回最后一个不完整的批次
                else:
                    raise StopIteration  # 否则抛出迭代结束异常

            path = self.files[self.count]  # 获取当前文件路径
            if self.video_flag[self.count]:  # 检查当前文件是否为视频
                self.mode = "video"  # 设置模式为视频
                if not self.cap or not self.cap.isOpened():  # 如果视频捕获对象不存在或未打开
                    self._new_video(path)  # 创建新的视频捕获对象

                for _ in range(self.vid_stride):  # 循环抓取视频帧
                    success = self.cap.grab()
                    if not success:
                        break  # 如果抓取失败,则退出循环

                if success:  # 如果抓取成功
                    success, im0 = self.cap.retrieve()  # 检索抓取的视频帧
                    if success:
                        self.frame += 1  # 帧数加一
                        paths.append(path)  # 添加路径到列表
                        imgs.append(im0)  # 添加图像帧到列表
                        info.append(f"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: ")  # 添加视频信息到列表
                        if self.frame == self.frames:  # 如果达到视频帧数的最大值
                            self.count += 1  # 计数器加一
                            self.cap.release()  # 释放视频捕获对象
                else:
                    # 如果当前视频结束或打开失败,移动到下一个文件
                    self.count += 1
                    if self.cap:
                        self.cap.release()  # 释放视频捕获对象
                    if self.count < self.nf:
                        self._new_video(self.files[self.count])  # 创建新的视频捕获对象
            else:
                self.mode = "image"  # 设置模式为图像
                im0 = cv2.imread(path)  # 读取图像(BGR格式)
                if im0 is None:
                    LOGGER.warning(f"WARNING ⚠️ Image Read Error {path}")  # 如果图像读取失败,记录警告信息
                else:
                    paths.append(path)  # 添加路径到列表
                    imgs.append(im0)  # 添加图像到列表
                    info.append(f"image {self.count + 1}/{self.nf} {path}: ")  # 添加图像信息到列表
                self.count += 1  # 计数器加一,移动到下一个文件
                if self.count >= self.ni:  # 如果计数器超过图像总数
                    break  # 跳出循环,结束图像列表的读取

        return paths, imgs, info  # 返回路径、图像/视频帧和元数据信息列表

    def _new_video(self, path):
        """Creates a new video capture object for the given path."""
        self.frame = 0  # 初始化帧数
        self.cap = cv2.VideoCapture(path)  # 创建新的视频捕获对象
        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))  # 获取视频帧率
        if not self.cap.isOpened():
            raise FileNotFoundError(f"Failed to open video {path}")  # 如果视频打开失败,抛出文件未找到异常
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)  # 计算视频帧数

    def __len__(self):
        """Returns the number of batches in the object."""
        return math.ceil(self.nf / self.bs)  # 返回对象中批次的数量,向上取整
    """
    Load images from PIL and Numpy arrays for batch processing.

    This class is designed to manage loading and pre-processing of image data from both PIL and Numpy formats.
    It performs basic validation and format conversion to ensure that the images are in the required format for
    downstream processing.

    Attributes:
        paths (list): List of image paths or autogenerated filenames.
        im0 (list): List of images stored as Numpy arrays.
        mode (str): Type of data being processed, defaults to 'image'.
        bs (int): Batch size, equivalent to the length of `im0`.

    Methods:
        _single_check(im): Validate and format a single image to a Numpy array.
    """

    def __init__(self, im0):
        """Initialize PIL and Numpy Dataloader."""
        if not isinstance(im0, list):
            im0 = [im0]
        # Generate filenames or use existing ones from input images
        self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
        # Validate and convert each image in `im0` to Numpy arrays
        self.im0 = [self._single_check(im) for im in im0]
        # Set the processing mode to 'image'
        self.mode = "image"
        # Set the batch size to the number of images in `im0`
        self.bs = len(self.im0)

    @staticmethod
    def _single_check(im):
        """Validate and format an image to numpy array."""
        # Ensure `im` is either a PIL.Image or np.ndarray
        assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
        if isinstance(im, Image.Image):
            # Convert PIL.Image to RGB mode if not already
            if im.mode != "RGB":
                im = im.convert("RGB")
            # Convert PIL.Image to Numpy array and reverse channels
            im = np.asarray(im)[:, :, ::-1]
            im = np.ascontiguousarray(im)  # Make sure the array is contiguous
        return im

    def __len__(self):
        """Returns the length of the 'im0' attribute."""
        return len(self.im0)

    def __next__(self):
        """Returns batch paths, images, processed images, None, ''."""
        if self.count == 1:  # loop only once as it's batch inference
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, [""] * self.bs

    def __iter__(self):
        """Enables iteration for class LoadPilAndNumpy."""
        self.count = 0
        return self


class LoadTensor:
    """
    Load images from torch.Tensor data.

    This class manages the loading and pre-processing of image data from PyTorch tensors for further processing.

    Attributes:
        im0 (torch.Tensor): The validated input tensor containing the image(s), always 4-dimensional.
        bs (int): Batch size, inferred from the first dimension of `im0`.
        mode (str): Current mode, set to 'image'.
        paths (list): List of image paths or filenames, one per image in the batch.
        count (int): Counter for iteration, initialized at 0 during `__iter__()`.

    Methods:
        _single_check(im, stride): Validate and possibly modify the input tensor.
    """

    def __init__(self, im0) -> None:
        """Initialize Tensor Dataloader."""
        self.im0 = self._single_check(im0)
        self.bs = self.im0.shape[0]
        self.mode = "image"
        # Enumerate the validated tensor, not the raw input: a 3-D (CHW) input gains a batch dimension in
        # `_single_check`, and iterating the raw tensor would produce one path per channel instead of one per image.
        self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(self.im0)]

    @staticmethod
    def _single_check(im, stride=32):
        """
        Validate and format an image to torch.Tensor.

        Args:
            im (torch.Tensor): Input of 3 or 4 dimensions; a 3-D input gets a leading batch dimension.
            stride (int): Value that dimensions 2 and 3 must be divisible by.

        Returns:
            (torch.Tensor): 4-D tensor, divided by 255 when its maximum exceeds 1.0.

        Raises:
            ValueError: If the input is neither 3-D nor 4-D, or dimensions 2/3 are not divisible by `stride`.
        """
        s = (
            f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
            f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
        )
        if len(im.shape) != 4:
            if len(im.shape) != 3:
                raise ValueError(s)
            # 3-D input is accepted with a warning and promoted to a batch of one
            LOGGER.warning(s)
            im = im.unsqueeze(0)
        if im.shape[2] % stride or im.shape[3] % stride:
            raise ValueError(s)
        if im.max() > 1.0 + torch.finfo(im.dtype).eps:  # torch.float32 eps is 1.2e-07
            LOGGER.warning(
                f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. "
                f"Dividing input by 255."
            )
            im = im.float() / 255.0

        return im

    def __iter__(self):
        """Reset the iteration counter and return the loader itself."""
        self.count = 0
        return self

    def __next__(self):
        """Return the paths, the tensor and a list of empty info strings, exactly once per iteration."""
        if self.count == 1:
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, [""] * self.bs

    def __len__(self):
        """Returns the batch size."""
        return self.bs
def autocast_list(source):
    """
    Convert a mixed list of sources into a list of PIL images and numpy arrays.

    Args:
        source (list): Elements may be filenames or URIs (str or Path), PIL Images, or numpy arrays.

    Returns:
        list: PIL Images opened from the filenames/URIs, with PIL Images and numpy arrays passed through unchanged.

    Raises:
        TypeError: If an element is not of a supported type.

    """
    files = []
    for im in source:
        if isinstance(im, (str, Path)):
            # Sources starting with "http" are fetched with requests; everything else is opened as a local file.
            target = requests.get(im, stream=True).raw if str(im).startswith("http") else im
            files.append(Image.open(target))
            continue
        if not isinstance(im, (Image.Image, np.ndarray)):
            raise TypeError(
                f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
                f"See https://docs.ultralytics.com/modes/predict for supported source types."
            )
        files.append(im)  # PIL or np image, kept as-is

    return files


def get_best_youtube_url(url, method="pytube"):
    """
    Retrieves the URL of the best quality MP4 video stream from a given YouTube video.

    Args:
        url (str): The URL of the YouTube video.
        method (str): The method to use for extracting video info. Default is "pytube". Other options are "pafy" and
            "yt-dlp".

    Returns:
        str: The URL of the best quality MP4 video stream, or None if no suitable stream is found or `method` is
            not one of the supported options.

    """
    if method == "pytube":
        check_requirements("pytubefix>=6.5.2")
        from pytubefix import YouTube

        def _height(stream):
            """Return the numeric part of a resolution string such as '1080p', or 0 when it is missing."""
            digits = (stream.resolution or "")[:-1]
            return int(digits) if digits.isdigit() else 0

        # Video-only MP4 streams, highest resolution first. Sorting on the parsed number avoids the
        # string ordering ("720p" > "1080p") and the failure on streams without a resolution.
        streams = YouTube(url).streams.filter(file_extension="mp4", only_video=True)
        for stream in sorted(streams, key=_height, reverse=True):
            if _height(stream) >= 1080:  # only accept streams of at least 1080p
                return stream.url

    elif method == "pafy":
        check_requirements(("pafy", "youtube_dl==2020.12.2"))
        import pafy  # noqa

        return pafy.new(url).getbestvideo(preftype="mp4").url

    elif method == "yt-dlp":
        check_requirements("yt-dlp")
        import yt_dlp

        # Query the metadata only; nothing is downloaded
        with yt_dlp.YoutubeDL({"quiet": True}) as ydl:
            info_dict = ydl.extract_info(url, download=False)

        # Iterate in reverse because the best formats are usually listed last
        for f in reversed(info_dict.get("formats", [])):
            # Accept video-only MP4 formats that are at least 1920 wide or 1080 high
            good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080
            video_only = f.get("vcodec") != "none" and f.get("acodec") == "none"
            if good_size and video_only and f.get("ext") == "mp4":
                return f.get("url")

    return None
# Tuple collecting the four loader classes
LOADERS = (LoadStreams, LoadPilAndNumpy, LoadImagesAndVideos, LoadScreenshots)

.\yolov8\ultralytics\data\split_dota.py

# 导入必要的库和模块
import itertools  # 导入 itertools 库,用于迭代操作
from glob import glob  # 从 glob 模块中导入 glob 函数,用于文件路径的匹配
from math import ceil  # 导入 math 模块中的 ceil 函数,用于向上取整
from pathlib import Path  # 导入 pathlib 模块中的 Path 类,用于处理路径操作

import cv2  # 导入 OpenCV 库
import numpy as np  # 导入 NumPy 库
from PIL import Image  # 从 PIL 库中导入 Image 模块
from tqdm import tqdm  # 导入 tqdm 库,用于显示进度条

from ultralytics.data.utils import exif_size, img2label_paths  # 导入 ultralytics.data.utils 中的函数
from ultralytics.utils.checks import check_requirements  # 从 ultralytics.utils.checks 导入 check_requirements 函数

# 检查并确保安装了 shapely 库
check_requirements("shapely")
from shapely.geometry import Polygon  # 导入 shapely 库中的 Polygon 类


def bbox_iof(polygon1, bbox2, eps=1e-6):
    """
    Calculate the intersection-over-foreground (IoF) of each polygon in polygon1 against each box in bbox2.

    Args:
        polygon1 (np.ndarray): Polygon coordinates, reshaped internally to (n, 4, 2).
        bbox2 (np.ndarray): Bounding boxes read as (left, top, right, bottom) along the last axis.
        eps (float): Lower bound applied to the polygon areas before dividing.

    Returns:
        (np.ndarray): Intersection area divided by polygon area; for 2-D `bbox2` one row per polygon and one
            column per box.
    """
    quads = polygon1.reshape(-1, 4, 2)
    # Axis-aligned envelope of each quadrilateral: (x_min, y_min, x_max, y_max)
    envelopes = np.concatenate([np.min(quads, axis=-2), np.max(quads, axis=-2)], axis=-1)

    # Envelope-vs-box overlap; only pairs with a non-zero value get the exact polygon intersection below
    top_left = np.maximum(envelopes[:, None, :2], bbox2[..., :2])
    bottom_right = np.minimum(envelopes[:, None, 2:], bbox2[..., 2:])
    extent = np.clip(bottom_right - top_left, 0, np.inf)
    coarse_overlap = extent[..., 0] * extent[..., 1]

    # Express every box as its four corner points so it can be turned into a shapely polygon
    x1, y1, x2, y2 = (bbox2[..., k] for k in range(4))
    box_corners = np.stack([x1, y1, x2, y1, x2, y2, x1, y2], axis=-1).reshape(-1, 4, 2)

    quad_shapes = [Polygon(q) for q in quads]
    box_shapes = [Polygon(b) for b in box_corners]
    exact_overlap = np.zeros(coarse_overlap.shape)
    for index in zip(*np.nonzero(coarse_overlap)):
        exact_overlap[index] = quad_shapes[index[0]].intersection(box_shapes[index[-1]]).area

    # The denominator is the area of the polygon itself (not a union), clipped away from zero
    quad_areas = np.array([shape.area for shape in quad_shapes], dtype=np.float32)[..., None]
    quad_areas = np.clip(quad_areas, eps, np.inf)
    iof = exact_overlap / quad_areas
    return iof[..., None] if iof.ndim == 1 else iof


def load_yolo_dota(data_root, split="train"):
    """
    Load DOTA dataset.

    Args:
        data_root (str): Data root.
        split (str): The split data set, could be train or val.

    Returns:
        (list[dict]): One dict per image with keys `ori_size` ((h, w)), `label` (float32 array parsed from the
            label file) and `filepath` (image path).

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
    """
    assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
    im_dir = Path(data_root) / "images" / split
    assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
    im_files = glob(str(im_dir / "*"))
    lb_files = img2label_paths(im_files)
    annos = []
    for im_file, lb_file in zip(im_files, lb_files):
        # Only the header is needed for the size; close the file handle right away instead of leaking it
        with Image.open(im_file) as im:
            w, h = exif_size(im)
        with open(lb_file) as f:
            lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
        lb = np.array(lb, dtype=np.float32)
        annos.append(dict(ori_size=(h, w), label=lb, filepath=im_file))
    return annos


def get_windows(im_size, crop_sizes=(1024,), gaps=(200,), im_rate_thr=0.6, eps=0.01):
    """
    Get the coordinates of windows.

    Args:
        im_size (tuple): Original image size, (h, w).
        crop_sizes (List(int)): Crop size of windows.
        gaps (List(int)): Gap between crops.
        im_rate_thr (float): Threshold of windows areas divided by image areas.
        eps (float): Epsilon value for math operations.

    Returns:
        (np.ndarray): Kept windows as rows of (x_start, y_start, x_stop, y_stop).
    """
    h, w = im_size
    windows = []
    for crop_size, gap in zip(crop_sizes, gaps):
        assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
        step = crop_size - gap

        # Window start positions along the width; the last one is shifted back to end at the image border
        xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)
        xs = [step * i for i in range(xn)]
        if len(xs) > 1 and xs[-1] + crop_size > w:
            xs[-1] = w - crop_size

        # Same along the height
        yn = 1 if h <= crop_size else ceil((h - crop_size) / step + 1)
        ys = [step * i for i in range(yn)]
        if len(ys) > 1 and ys[-1] + crop_size > h:
            ys[-1] = h - crop_size

        # Every (x, y) start combined with its stop corner
        start = np.array(list(itertools.product(xs, ys)), dtype=np.int64)
        stop = start + crop_size
        windows.append(np.concatenate([start, stop], axis=1))

    windows = np.concatenate(windows, axis=0)

    # Part of each window that actually lies inside the image
    im_in_wins = windows.copy()
    im_in_wins[:, 0::2] = np.clip(im_in_wins[:, 0::2], 0, w)
    im_in_wins[:, 1::2] = np.clip(im_in_wins[:, 1::2], 0, h)

    # Fraction of each window covered by image content
    im_areas = (im_in_wins[:, 2] - im_in_wins[:, 0]) * (im_in_wins[:, 3] - im_in_wins[:, 1])
    win_areas = (windows[:, 2] - windows[:, 0]) * (windows[:, 3] - windows[:, 1])
    im_rates = im_areas / win_areas

    # If no window passes the threshold, keep the ones within eps of the best rate
    if not (im_rates > im_rate_thr).any():
        max_rate = im_rates.max()
        im_rates[abs(im_rates - max_rate) < eps] = 1

    return windows[im_rates > im_rate_thr]
# Collect the labelled objects that belong to each window.
def get_window_obj(anno, windows, iof_thr=0.7):
    """
    Get objects for each window.

    Args:
        anno (dict): Annotation dict providing `ori_size` ((h, w)) and `label`.
        windows (np.ndarray): Window coordinates, one window per row.
        iof_thr (float): Minimum IoF an object needs to be assigned to a window.

    Returns:
        (list[np.ndarray]): For every window, the label rows whose IoF with that window is at least `iof_thr`.

    Notes:
        A non-empty `anno["label"]` is scaled in place by the image width and height.
    """
    h, w = anno["ori_size"]
    label = anno["label"]
    if not len(label):
        # No objects in this image: every window gets an empty (0, 9) array
        return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))]  # window_anns

    # Odd columns are multiplied by the width, even columns (from 2 on) by the height
    label[:, 1::2] *= w
    label[:, 2::2] *= h
    iofs = bbox_iof(label[:, 1:], windows)
    return [label[iofs[:, win_idx] >= iof_thr] for win_idx in range(len(windows))]  # window_anns


# Crop the image into windows and write the matching labels.
def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
    """
    Crop images and save new labels.

    Args:
        anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys.
        windows (list): A list of windows coordinates.
        window_objs (list): A list of labels inside each window.
        im_dir (str): The output directory path of images.
        lb_dir (str): The output directory path of labels.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
        The arrays in `window_objs` are modified in place while being converted to patch coordinates.
    """
    image = cv2.imread(anno["filepath"])
    stem = Path(anno["filepath"]).stem
    for idx, window in enumerate(windows):
        x0, y0, x1, y1 = window.tolist()
        patch = image[y0:y1, x0:x1]
        patch_h, patch_w = patch.shape[:2]
        # The patch name encodes the window width and its top-left corner
        patch_name = f"{stem}__{x1 - x0}__{x0}___{y0}"
        cv2.imwrite(str(Path(im_dir) / f"{patch_name}.jpg"), patch)

        objs = window_objs[idx]
        if len(objs) == 0:
            continue  # image is saved, but no label file is written for an empty window

        # Shift coordinates to the patch origin, then normalize by the actual patch size
        objs[:, 1::2] -= x0
        objs[:, 2::2] -= y0
        objs[:, 1::2] /= patch_w
        objs[:, 2::2] /= patch_h

        with open(Path(lb_dir) / f"{patch_name}.txt", "w") as f:
            for row in objs:
                coords = " ".join(f"{value:.6g}" for value in row[1:])
                f.write(f"{int(row[0])} {coords}\n")


# Split the images and labels of one dataset split into windows.
def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=(1024,), gaps=(200,)):
    """
    Split both images and labels.

    Args:
        data_root (str): Root directory of the source dataset.
        save_dir (str): Root directory the crops and labels are written to.
        split (str): Name of the split sub-directory to process.
        crop_sizes (tuple): Window sizes forwarded to `get_windows`.
        gaps (tuple): Window gaps forwarded to `get_windows`.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - split
                - labels
                    - split
        and the output directory structure is:
            - save_dir
                - images
                    - split
                - labels
                    - split
    """
    out_root = Path(save_dir)
    im_dir, lb_dir = out_root / "images" / split, out_root / "labels" / split
    for out_dir in (im_dir, lb_dir):
        out_dir.mkdir(parents=True, exist_ok=True)

    annos = load_yolo_dota(data_root, split=split)
    for anno in tqdm(annos, total=len(annos), desc=split):
        # Windows for this image, the objects falling into each of them, then the crops on disk
        windows = get_windows(anno["ori_size"], crop_sizes, gaps)
        crop_and_save(anno, windows, get_window_obj(anno, windows), str(im_dir), str(lb_dir))
# Split the train and val sets of the DOTA dataset into crops.
def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
    """
    Split train and val set of DOTA.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
        and the output directory structure is:
            - save_dir
                - images
                    - train
                    - val
                - labels
                    - train
                    - val

    Parameters:
        data_root (str): Root directory of the dataset.
        save_dir (str): Directory the split dataset is saved to.
        crop_size (int): Base crop size, 1024 by default.
        gap (int): Base gap between crops, 200 by default.
        rates (tuple): Factors the crop size and gap are divided by, (1.0,) by default.

    Returns:
        None
    """
    # One (crop size, gap) pair per rate, each truncated to an int
    crop_sizes = [int(crop_size / rate) for rate in rates]
    gaps = [int(gap / rate) for rate in rates]
    for split in ("train", "val"):
        split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)


# Split the test set of the DOTA dataset into crops.
def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
    """
    Split test set of DOTA, labels are not included within this set.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - test
        and the output directory structure is:
            - save_dir
                - images
                    - test

    Parameters:
        data_root (str): Root directory of the dataset.
        save_dir (str): Directory the split dataset is saved to.
        crop_size (int): Base crop size, 1024 by default.
        gap (int): Base gap between crops, 200 by default.
        rates (tuple): Factors the crop size and gap are divided by, (1.0,) by default.

    Returns:
        None

    Raises:
        AssertionError: If `data_root/images/test` does not exist.
    """
    # One (crop size, gap) pair per rate, each truncated to an int
    crop_sizes, gaps = [], []
    for r in rates:
        crop_sizes.append(int(crop_size / r))
        gaps.append(int(gap / r))
    save_dir = Path(save_dir) / "images" / "test"
    save_dir.mkdir(parents=True, exist_ok=True)

    im_dir = Path(data_root) / "images" / "test"
    assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
    im_files = glob(str(im_dir / "*"))
    for im_file in tqdm(im_files, total=len(im_files), desc="test"):
        # Only the header is needed for the size; close the file handle right away instead of leaking it
        with Image.open(im_file) as pil_im:
            w, h = exif_size(pil_im)
        windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
        im = cv2.imread(im_file)
        name = Path(im_file).stem
        for window in windows:
            x_start, y_start, x_stop, y_stop = window.tolist()
            # The patch name encodes the window width and its top-left corner
            new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
            patch_im = im[y_start:y_stop, x_start:x_stop]
            cv2.imwrite(str(save_dir / f"{new_name}.jpg"), patch_im)


if __name__ == "__main__":
    # Split the train and val sets into crops
    split_trainval(data_root="DOTAv2", save_dir="DOTAv2-split")
    # Split the test set into crops
    split_test(data_root="DOTAv2", save_dir="DOTAv2-split")

.\yolov8\ultralytics\data\utils.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import hashlib
import json
import os
import random
import subprocess
import time
import zipfile
from multiprocessing.pool import ThreadPool
from pathlib import Path
from tarfile import is_tarfile

import cv2
import numpy as np
from PIL import Image, ImageOps

# 导入自定义模块和函数
from ultralytics.nn.autobackend import check_class_names
from ultralytics.utils import (
    DATASETS_DIR,
    LOGGER,
    NUM_THREADS,
    ROOT,
    SETTINGS_YAML,
    TQDM,
    clean_url,
    colorstr,
    emojis,
    is_dir_writeable,
    yaml_load,
    yaml_save,
)
# 导入数据校验函数和下载函数
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
# 导入操作函数
from ultralytics.utils.ops import segments2boxes

# Help text pointing users at the dataset formatting documentation
HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance."
# File suffixes (lower-case, without the leading dot) accepted as images and videos
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp"}  # image suffixes
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
# True when the PIN_MEMORY environment variable is unset or equals "true" (case-insensitive); any other value
# yields False
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
# Message listing the supported formats, appended to format-related errors
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"


def img2label_paths(img_paths):
    """
    Map image paths to their label paths.

    The last `<sep>images<sep>` component of each path is replaced by `<sep>labels<sep>` and the text after the
    final dot is replaced by `txt`.

    Args:
        img_paths (list[str]): Image file paths.

    Returns:
        (list[str]): One label path per input path, in the same order.
    """
    images_part = f"{os.sep}images{os.sep}"  # /images/ substring
    labels_part = f"{os.sep}labels{os.sep}"  # /labels/ substring
    label_paths = []
    for img_path in img_paths:
        # Only the right-most "/images/" occurrence is swapped
        head_and_tail = img_path.rsplit(images_part, 1)
        relocated = labels_part.join(head_and_tail)
        # Drop everything after the final dot and use the .txt extension instead
        without_ext = relocated.rsplit(".", 1)[0]
        label_paths.append(without_ext + ".txt")
    return label_paths


def get_hash(paths):
    """
    Return a single SHA-256 hex digest for a list of paths (files or dirs).

    The digest covers the summed `os.path.getsize` of every path that exists, followed by the concatenation of all
    path strings.

    Args:
        paths (list[str]): Paths to hash; entries that do not exist contribute no size.

    Returns:
        (str): Hexadecimal SHA-256 digest.
    """
    total_size = 0
    for p in paths:
        if os.path.exists(p):
            total_size += os.path.getsize(p)
    digest = hashlib.sha256(str(total_size).encode())  # hash sizes
    joined_paths = "".join(paths)
    digest.update(joined_paths.encode())  # hash paths
    return digest.hexdigest()


def exif_size(img: Image.Image):
    """
    Return the EXIF-corrected PIL image size.

    Args:
        img (Image.Image): Opened PIL image.

    Returns:
        (tuple): `img.size` as (width, height), with the two entries swapped when a JPEG carries EXIF orientation
            6 or 8. Any error while reading EXIF data is suppressed and the uncorrected size is returned.
    """
    size = img.size  # (width, height)
    if img.format != "JPEG":  # only JPEG images are corrected
        return size
    with contextlib.suppress(Exception):
        exif = img.getexif()
        # 274 is the EXIF key of the orientation tag
        orientation = exif.get(274, None) if exif else None
        if orientation in {6, 8}:  # rotation 270 or 90
            size = size[1], size[0]
    return size


def verify_image(args):
    """
    Verify one image.

    Args:
        args (tuple): `((im_file, cls), prefix)` - the image path with its class, and a prefix for log messages.

    Returns:
        (tuple): `((im_file, cls), nf, nc, msg)` where `nf` is 1 when the image passed all checks, `nc` is 1 when
            any check raised, and `msg` is a warning string (empty when there is nothing to report).
    """
    (im_file, cls), prefix = args
    found, corrupt, message = 0, 0, ""
    try:
        image = Image.open(im_file)
        image.verify()  # PIL verify
        size = exif_size(image)  # (width, height)
        shape = (size[1], size[0])  # reorder to (height, width)
        assert shape[0] > 9 and shape[1] > 9, f"image size {shape} <10 pixels"
        fmt = image.format.lower()
        assert fmt in IMG_FORMATS, f"Invalid image format {image.format}. {FORMATS_HELP_MSG}"
        if fmt in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                # A complete JPEG stream ends with the two-byte marker FF D9
                f.seek(-2, 2)
                tail = f.read()
                if tail != b"\xff\xd9":  # corrupt JPEG
                    # Re-encode the image in place to restore a well-formed file
                    reopened = Image.open(im_file)
                    ImageOps.exif_transpose(reopened).save(im_file, "JPEG", subsampling=0, quality=100)
                    message = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
        found = 1
    except Exception as e:
        # Any failure marks the image as corrupt instead of propagating
        corrupt = 1
        message = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
    return (im_file, cls), found, corrupt, message
# Verify a single image-label pair.
# NOTE(review): in this excerpt the function body between the initialisation and the `except` clause is missing
# (there is no matching `try:`), so the block is not valid Python as shown; restore the body from the upstream
# source before using it.
def verify_image_label(args):
    # Unpack the arguments: image file path, label file path, log prefix, keypoint, number of classes,
    # number of keypoints, number of dimensions
    im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
    # Initialise counters and result holders
    # nm: number missing
    # nf: number found
    # ne: number empty
    # nc: number corrupt
    # msg: message string
    # segments: segments
    # keypoints: keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None

    # Treat any exception as a corrupt image/label pair
    except Exception as e:
        # Count the pair as corrupt
        nc = 1
        # Build the warning message naming the file and the underlying error
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        # Return None for the five data fields, followed by the counters and the message
        return [None, None, None, None, None, nm, nf, ne, nc, msg]


def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
    """
    Convert a list of polygons to a binary mask of the specified image size.

    The polygons are filled on a full-resolution mask first; the mask is resized afterwards.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
                                     N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int, optional): The color value to fill in the polygons on the mask. Defaults to 1.
        downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1.

    Returns:
        (np.ndarray): A uint8 mask of shape (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio) with the
            polygons filled.
    """
    canvas = np.zeros(imgsz, dtype=np.uint8)
    # Integer vertex array shaped (num_polygons, num_points, 2), as passed to cv2.fillPoly
    vertices = np.asarray(polygons, dtype=np.int32)
    vertices = vertices.reshape((vertices.shape[0], -1, 2))
    cv2.fillPoly(canvas, vertices, color=color)
    out_h = imgsz[0] // downsample_ratio
    out_w = imgsz[1] // downsample_ratio
    # cv2.resize takes the target size as (width, height)
    return cv2.resize(canvas, (out_w, out_h))


def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
    """
    Convert a list of polygons to a set of binary masks of the specified image size.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
                                     N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int): The color value to fill in the polygons on the masks.
        downsample_ratio (int, optional): Factor by which to downsample each mask. Defaults to 1.

    Returns:
        (np.ndarray): One mask per polygon, stacked into a single array.
    """
    masks = []
    for polygon in polygons:
        # Each polygon is flattened and rasterised on its own mask
        flat = polygon.reshape(-1)
        masks.append(polygon2mask(imgsz, [flat], color, downsample_ratio))
    return np.array(masks)


def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
    """
    Build a single overlap mask in which each segment is labelled by its area rank.

    Segments are rasterised individually, sorted by mask area in descending order, and written into one mask with
    labels 1..len(segments); after each addition the mask is clipped to the current label.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        segments (list): Segment arrays; each is flattened before rasterisation.
        downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1.

    Returns:
        (tuple): `(masks, index)` - the overlap mask of shape
            (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio) and the indices that sort the segments
            by descending area.
    """
    num_segments = len(segments)
    overlap = np.zeros(
        (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
        # uint8 is used unless there are more than 255 segments
        dtype=np.int32 if num_segments > 255 else np.uint8,
    )
    # One binary mask per segment
    instance_masks = [
        polygon2mask(imgsz, [segment.reshape(-1)], downsample_ratio=downsample_ratio, color=1)
        for segment in segments
    ]
    areas = np.asarray([m.sum() for m in instance_masks])
    index = np.argsort(-areas)  # largest area first
    ordered = np.array(instance_masks)[index]

    for label, instance_mask in enumerate(ordered, start=1):
        # Add the labelled mask, then cap overlapping pixels at the current label
        overlap = np.clip(overlap + instance_mask * label, a_min=0, a_max=label)

    return overlap, index
def find_dataset_yaml(path: Path) -> Path:
    """
    Find and return the YAML file associated with a Detect, Segment or Pose dataset.

    The root level of `path` is searched first; only when it holds no `*.yaml` file is a recursive search done.
    When several candidates are found, only those whose stem equals the stem of `path` are kept.

    Args:
        path (Path): The directory path to search for the YAML file.

    Returns:
        (Path): The path of the found YAML file.

    Raises:
        AssertionError: If no YAML file is found, or if the candidates cannot be narrowed down to exactly one.
    """
    candidates = list(path.glob("*.yaml"))  # root level first
    if not candidates:
        candidates = list(path.rglob("*.yaml"))  # then recursive
    assert candidates, f"No YAML file found in '{path.resolve()}'"

    if len(candidates) > 1:
        # Prefer *.yaml files named after the dataset directory
        candidates = [candidate for candidate in candidates if candidate.stem == path.stem]

    assert len(candidates) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(candidates)}.\n{candidates}"
    return candidates[0]


def check_det_dataset(dataset, autodownload=True):
    """
    Download, verify, and/or unzip a dataset if not found locally.

    This function checks the availability of a specified dataset, and if not found, it has the option to download and
    unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
    resolves paths related to the dataset.

    Args:
        dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
        autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.

    Returns:
        (dict): Parsed dataset information and paths.

    Raises:
        SyntaxError: If the YAML lacks 'train' or 'val' (or 'validation'), lacks both 'names' and 'nc', or has
            'names' and 'nc' that disagree.
        FileNotFoundError: If a 'val' path is missing and no download is attempted (no 'download' entry in the
            YAML, or `autodownload` is False).
    """

    # Resolve the dataset argument to a local file path
    file = check_file(dataset)

    # If the dataset file is a ZIP or TAR archive, extract it under DATASETS_DIR and use the YAML found inside
    extract_dir = ""
    if zipfile.is_zipfile(file) or is_tarfile(file):
        new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
        # Locate the YAML file within the extracted directory
        file = find_dataset_yaml(DATASETS_DIR / new_dir)
        # The dataset root becomes the YAML's directory; autodownload is switched off for archives
        extract_dir, autodownload = file.parent, False

    # Load YAML data from the specified file, appending the filename to the loaded data
    data = yaml_load(file, append_filename=True)  # dictionary

    # Perform checks on the loaded YAML data
    for k in "train", "val":
        if k not in data:
            if k != "val" or "validation" not in data:
                # 'train' is always required; 'val' may only be absent when 'validation' is present
                raise SyntaxError(
                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
                )
            # Rename the 'validation' key to 'val' (logged at info level)
            LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key

    # Ensure 'names' or 'nc' keys are present in the data; otherwise, raise a SyntaxError
    if "names" not in data and "nc" not in data:
        raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))

    # Ensure the lengths of 'names' and 'nc' match if both are present
    if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
        raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
    # Without 'names', generate placeholder names "class_0" .. "class_{nc-1}";
    # with 'names', set 'nc' to the number of names
    if "names" not in data:
        data["names"] = [f"class_{i}" for i in range(data["nc"])]
    else:
        data["nc"] = len(data["names"])

    # Pass the names through check_class_names() and keep its result
    data["names"] = check_class_names(data["names"])

    # Resolve the dataset root: the extraction directory if an archive was unpacked, else the YAML's 'path' entry,
    # else the directory containing the YAML file
    path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent)  # dataset root
    if not path.is_absolute():
        path = (DATASETS_DIR / path).resolve()  # relative roots are taken relative to DATASETS_DIR

    # Store the resolved root back into the data dictionary
    data["path"] = path  # download scripts

    # Make every split entry that is present and non-empty absolute, relative to the dataset root
    for k in "train", "val", "test", "minival":
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                # Single path given as a string
                x = (path / data[k]).resolve()
                # If that path does not exist and the entry starts with "../", retry without the "../" prefix
                if not x.exists() and data[k].startswith("../"):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                # Otherwise the entry is iterated as a collection of paths
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Read the validation path(s) and the optional 'download' entry
    val, s = (data.get(x) for x in ("val", "download"))
    if val:
        # Normalise val to a list of resolved paths
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        # Attempt a download (or fail) only when at least one validation path is missing
        if not all(x.exists() for x in val):
            name = clean_url(dataset)  # dataset name with URL auth stripped
            # Message naming the first missing path
            m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
            if s and autodownload:
                LOGGER.warning(m)
            else:
                m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"
                raise FileNotFoundError(m)
            t = time.time()
            r = None  # success
            # SECURITY NOTE(review): `s` comes from the YAML's 'download' entry and is handed to os.system() or
            # exec() below, so only dataset YAMLs from trusted sources should be used with autodownload enabled.
            if s.startswith("http") and s.endswith(".zip"):  # URL
                safe_download(url=s, dir=DATASETS_DIR, delete=True)
            elif s.startswith("bash "):  # bash script
                LOGGER.info(f"Running {s} ...")
                r = os.system(s)
            else:  # python script, executed with the dataset dict exposed as `yaml`
                exec(s, {"yaml": data})
            dt = f"({round(time.time() - t, 1)}s)"
            # Only the bash branch sets r, so failure is reported only for a non-zero os.system() result
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
            LOGGER.info(f"Dataset download {s}\n")

    # Ensure a font is available: Arial.ttf when the names are ASCII, Arial.Unicode.ttf otherwise
    check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts

    return data  # the updated dataset dictionary
    # 检查分类数据集,如Imagenet。

    # 如果 `dataset` 以 "http:/" 或 "https:/" 开头,尝试从网络下载数据集并保存到本地。
    # 如果 `dataset` 是以 ".zip", ".tar", 或 ".gz" 结尾的文件路径,检查文件的有效性后,下载并解压数据集到指定目录。

    # 将 `dataset` 转换为 `Path` 对象,并解析其绝对路径。
    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()

    # 如果指定路径的数据集不存在,尝试从网络下载。
    if not data_dir.is_dir():
        # 如果 `dataset` 是 "imagenet",执行特定的数据集下载脚本。
        # 否则,从 GitHub 发布的资源中下载指定的数据集压缩文件。
        LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
        t = time.time()
        if str(dataset) == "imagenet":
            subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
        else:
            url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{dataset}.zip"
            download(url, dir=data_dir.parent)
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        LOGGER.info(s)

    # 训练集的路径
    train_set = data_dir / "train"

    # 验证集的路径,优先选择 "val" 目录,其次选择 "validation" 目录,如果都不存在则为 None。
    val_set = (
        data_dir / "val"
        if (data_dir / "val").exists()
        else data_dir / "validation"
        if (data_dir / "validation").exists()
        else None
    )  # data/test or data/val

    # 测试集的路径,优先选择 "test" 目录,如果不存在则为 None。
    test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/val or data/test

    # 如果 `split` 参数为 "val",但验证集路径 `val_set` 不存在时,发出警告并使用测试集路径代替。
    if split == "val" and not val_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
    
    # 如果 `split` 参数为 "test",但测试集路径 `test_set` 不存在时,发出警告并使用验证集路径代替。
    elif split == "test" and not test_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")

    # 计算数据集中的类别数目,通过统计 `train` 目录下的子目录数量来得到。
    nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()])  # number of classes

    # 获取训练集中的类别名称列表,并按字母顺序排序后构建成字典,键为类别索引。
    names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()]  # class names list
    names = dict(enumerate(sorted(names)))

    # 打印结果到控制台
    # 遍历包含训练集、验证集和测试集的字典,每次迭代获取键值对(k为键,v为对应的数据集)
    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
        # 使用f-string生成带颜色的前缀字符串,指示当前数据集的名称和状态
        prefix = f'{colorstr(f"{k}:")} {v}...'
        # 如果当前数据集为空(None),记录信息到日志
        if v is None:
            LOGGER.info(prefix)
        else:
            # 获取当前数据集中所有符合图像格式的文件路径列表
            files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
            # 计算当前数据集中的文件数目(nf)和不重复父目录数(nd)
            nf = len(files)  # 文件数目
            nd = len({file.parent for file in files})  # 不重复父目录数
            # 如果当前数据集中没有找到图像文件
            if nf == 0:
                # 如果是训练集,抛出文件未找到的错误并记录
                if k == "train":
                    raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
                else:
                    # 否则记录警告信息,指示没有找到图像文件
                    LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
            # 如果当前数据集中的类别数目与期望的类别数目不匹配
            elif nd != nc:
                # 记录警告信息,指示类别数目不匹配的错误
                LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
            else:
                # 记录信息,指示成功找到图像文件并且类别数目匹配
                LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

    # 返回包含训练集、验证集、测试集、类别数和类别名称的字典
    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
    """
    A class for generating HUB dataset JSON and `-hub` dataset directory.

    Args:
        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'.
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
        autodownload (bool): Attempt to download dataset if not found locally. Default is False.

    Example:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
            i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
        ```py
        from ultralytics.data.utils import HUBDatasetStats

        stats = HUBDatasetStats('path/to/coco8.zip', task='detect')  # detect dataset
        stats = HUBDatasetStats('path/to/coco8-seg.zip', task='segment')  # segment dataset
        stats = HUBDatasetStats('path/to/coco8-pose.zip', task='pose')  # pose dataset
        stats = HUBDatasetStats('path/to/dota8.zip', task='obb')  # OBB dataset
        stats = HUBDatasetStats('path/to/imagenet10.zip', task='classify')  # classification dataset

        stats.get_json(save=True)
        stats.process_images()
        ```
    """

    def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
        """Initialize class."""
        # Resolve the given path to its absolute form
        path = Path(path).resolve()
        # Log information message about starting dataset checks
        LOGGER.info(f"Starting HUB dataset checks for {path}....")

        # Initialize class attributes based on arguments
        self.task = task  # detect, segment, pose, classify

        # Depending on the task type, perform different operations
        if self.task == "classify":
            # Unzip the file and check the classification dataset
            unzip_dir = unzip_file(path)
            data = check_cls_dataset(unzip_dir)
            data["path"] = unzip_dir
        else:  # detect, segment, pose
            # Unzip the file, extract data directory and yaml path
            _, data_dir, yaml_path = self._unzip(Path(path))
            try:
                # Load YAML with checks
                data = yaml_load(yaml_path)
                # Strip path since YAML should be in dataset root for all HUB datasets
                data["path"] = ""
                yaml_save(yaml_path, data)
                # Perform dataset checks for detection dataset
                data = check_det_dataset(yaml_path, autodownload)  # dict
                # Set YAML path to data directory (relative) or parent (absolute)
                data["path"] = data_dir
            except Exception as e:
                # Raise an exception with a specific error message
                raise Exception("error/HUB/dataset_stats/init") from e

        # Set attributes for dataset directory and related paths
        self.hub_dir = Path(f'{data["path"]}-hub')
        self.im_dir = self.hub_dir / "images"
        # Create a statistics dictionary based on loaded data
        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}
        self.data = data
    # 解压缩指定路径的 ZIP 文件,并返回解压后的目录路径和数据集 YAML 文件路径
    def _unzip(path):
        """Unzip data.zip."""
        # 如果路径不是以 ".zip" 结尾,则认为是数据文件而非压缩文件,直接返回 False 表示未解压,以及原始路径
        if not str(path).endswith(".zip"):  # path is data.yaml
            return False, None, path
        # 调用 unzip_file 函数解压指定路径的 ZIP 文件到其父目录
        unzip_dir = unzip_file(path, path=path.parent)
        # 断言解压后的目录存在,否则输出错误信息,提示预期的解压路径
        assert unzip_dir.is_dir(), (
            f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/"
        )
        # 返回 True 表示成功解压,解压后的目录路径字符串,以及在解压目录中找到的数据集 YAML 文件路径
        return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path

    # 保存压缩后的图像用于 HUB 预览
    def _hub_ops(self, f):
        """Saves a compressed image for HUB previews."""
        # 调用 compress_one_image 函数,将指定文件 f 压缩保存到 self.im_dir 目录下,使用文件名作为保存的文件名
        compress_one_image(f, self.im_dir / Path(f).name)  # save to dataset-hub

    # 处理图像,为 Ultralytics HUB 压缩图像
    def process_images(self):
        """Compress images for Ultralytics HUB."""
        from ultralytics.data import YOLODataset  # ClassificationDataset

        # 创建目录 self.im_dir,如果不存在则创建,用于保存压缩后的图像文件
        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/images/
        
        # 遍历 "train", "val", "test" 三个数据集分割
        for split in "train", "val", "test":
            # 如果 self.data 中不存在当前分割的数据集,则跳过继续下一个分割
            if self.data.get(split) is None:
                continue
            # 创建 YOLODataset 对象,指定图像路径为 self.data[split],数据为 self.data
            dataset = YOLODataset(img_path=self.data[split], data=self.data)
            # 使用线程池 ThreadPool,并发处理图像压缩操作
            with ThreadPool(NUM_THREADS) as pool:
                # 使用 TQDM 显示进度条,遍历数据集中的图像文件,对每个图像文件调用 _hub_ops 方法进行压缩保存操作
                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                    pass
        # 输出日志信息,指示所有图像保存到 self.im_dir 目录下完成
        LOGGER.info(f"Done. All images saved to {self.im_dir}")
        # 返回保存压缩图像的目录路径
        return self.im_dir
# Automatically split a dataset into train/val/test and save the result to autosplit_*.txt files
def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
    """
    Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.

    The list files are written to the parent directory of `path`; existing autosplit_*.txt files there are removed
    first. The random generator is seeded with 0, so the assignment is reproducible.

    Args:
        path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
        weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
        annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.

    Example:
        ```py
        from ultralytics.data.utils import autosplit

        autosplit()
        ```
    """

    path = Path(path)  # images dir
    out_dir = path.parent
    # Image files only, in sorted order
    images = sorted(p for p in path.rglob("*.*") if p.suffix[1:].lower() in IMG_FORMATS)
    total = len(images)
    random.seed(0)  # for reproducibility
    # One split index (0=train, 1=val, 2=test) per image
    assignments = random.choices([0, 1, 2], weights=weights, k=total)

    list_names = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]
    # Remove list files left over from a previous run
    for list_name in list_names:
        stale = out_dir / list_name
        if stale.exists():
            stale.unlink()

    LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
    for split_idx, img in TQDM(zip(assignments, images), total=total):
        # With annotated_only, images without a label file are left out
        if annotated_only and not Path(img2label_paths([str(img)])[0]).exists():
            continue
        with open(out_dir / list_names[split_idx], "a") as f:
            # Entries are POSIX paths relative to the parent of the images directory
            f.write(f"./{img.relative_to(out_dir).as_posix()}" + "\n")
def load_dataset_cache_file(path):
    """
    Load an Ultralytics *.cache dictionary from path.

    Garbage collection is switched off while the file is unpickled and is restored to its previous state
    afterwards, including when loading fails.

    Args:
        path (str | Path): Path of the cache file written with `np.save`.

    Returns:
        (dict): The cached dictionary.

    Raises:
        Exception: Whatever `np.load` raises for a missing or unreadable file is propagated unchanged.
    """
    import gc

    gc_was_enabled = gc.isenabled()
    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
    try:
        # SECURITY NOTE(review): allow_pickle=True unpickles arbitrary objects; only load cache files from
        # trusted locations.
        return np.load(str(path), allow_pickle=True).item()  # load dict
    finally:
        # Restore rather than force-enable, so a caller that had GC disabled keeps it disabled
        if gc_was_enabled:
            gc.enable()


def save_dataset_cache_file(prefix, path, x, version):
    """
    Save an Ultralytics dataset *.cache dictionary x to path.

    The dictionary is written through an open file handle, so NumPy does not append a ".npy" suffix and the file
    lands at `path` whatever its extension is.

    Args:
        prefix (str): Prefix for log messages.
        path (Path): Destination cache file.
        x (dict): Cache dictionary; its "version" entry is set in place.
        version (str): Cache version stored under the "version" key.
    """
    x["version"] = version  # add cache version
    if not is_dir_writeable(path.parent):
        # A read-only directory is not an error: the cache is simply not saved
        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
        return
    if path.exists():
        path.unlink()  # remove *.cache file if exists
    with open(str(path), "wb") as file:
        np.save(file, x)  # save cache for next time
    LOGGER.info(f"{prefix}New cache created: {path}")
posted @ 2024-09-05 11:58  绝不原创的飞龙  阅读(3)  评论(0编辑  收藏  举报