Diffusers中Pipeline的数据类型是怎么设置和转化的,pipeline.dtype和pipeline.from_pretrained(torch_dtype)

  参考资料:

  Diffusers中DiffusionPipeline基类的[源码]

  众所周知Pipeline是Diffusers中最重要的一个API接口,一直以来我都对这个接口数据结构的获取一知半解,今天看了下源码终于知道了这个API结构的数据类型是如何设置的。直接看代码:

@property
def dtype(self) -> torch.dtype:
    r"""
    Returns:
        `torch.dtype`: The torch dtype on which the pipeline is located.
    """
    module_names, _ = self._get_signature_keys(self)
    modules = [getattr(self, n, None) for n in module_names]
    modules = [m for m in modules if isinstance(m, torch.nn.Module)]

    for module in modules:
        return module.dtype

    return torch.float32

  这里使用了一个@property的语法糖,使用该语法糖包裹住的方法能直接作为类的属性。即调用pipeline.type就会运行该函数,获得返回值。

  代码也是一目了然的,首先从自己注册了的参数列表里面去找模块,一旦发现torch.nn.Module类型的立刻返回,否则则直接返回默认值torch.float32。关于这个代码有两个相关需要注意的地方要说

  1. dtype如何初始化

  从dtype的方法体可以看出,dtype在初始化pipeline的时候不需要特别去设置,只要pipeline包含module,就一定会返回一个数据类型。而module的初始数据类型则往往是from_pretrained方法来定义的:即pipeline.from_pretrained(torch_dtype=torch.float16/float32)。

  2. 人工设置参数

  经常玩pipeline的朋友可能会手动设置pipeline中的模块,即在pipeline.from_pretrained的时候对模块进行手动赋值,但是这种做法就带来了一个问题,有可能导致pipeline中的数据类型不统一,因此我觉得一个比较好的做法是在参数列表中显式指定torch_dtype或者在初始化结束之后添加to(dtype)。

posted @ 2024-08-08 17:04  思念殇千寻  阅读(49)  评论(0编辑  收藏  举报