Diffusers中Pipeline的数据类型是怎么设置和转化的,pipeline.dtype和pipeline.from_pretrained(torch_dtype)
参考资料:
Diffusers中DiffusionPipeline基类的[源码]
众所周知Pipeline是Diffusers中最重要的一个API接口,一直以来我都对这个接口数据结构的获取一知半解,今天看了下源码终于知道了这个API结构的数据类型是如何设置的。直接看代码:
@property def dtype(self) -> torch.dtype: r""" Returns: `torch.dtype`: The torch dtype on which the pipeline is located. """ module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module)] for module in modules: return module.dtype return torch.float32
这里使用了一个@property的语法糖,使用该语法糖包裹住的方法能直接作为类的属性。即调用pipeline.type就会运行该函数,获得返回值。
代码也是一目了然的,首先从自己注册了的参数列表里面去找模块,一旦发现torch.nn.Module类型的立刻返回,否则则直接返回默认值torch.float32。关于这个代码有两个相关需要注意的地方要说
1. dtype如何初始化
从dtype的方法体可以看出,dtype在初始化pipeline的时候不需要特别去设置,只要pipeline包含module,就一定会返回一个数据类型。而module的初始数据类型则往往是from_pretrained方法来定义的:即pipeline.from_pretrained(torch_dtype=torch.float16/float32)。
2. 人工设置参数
经常玩pipeline的朋友可能会手动设置pipeline中的模块,即在pipeline.from_pretrained的时候对模块进行手动赋值,但是这种做法就带来了一个问题,有可能导致pipeline中的数据类型不统一,因此我觉得一个比较好的做法是在参数列表中显式指定torch_dtype或者在初始化结束之后添加to(dtype)。