Diffusers中Pipeline的数据类型是怎么设置和转化的,pipeline.dtype和pipeline.from_pretrained(torch_dtype)
参考资料:
Diffusers中DiffusionPipeline基类的[源码]
众所周知Pipeline是Diffusers中最重要的一个API接口,一直以来我都对这个接口数据结构的获取一知半解,今天看了下源码终于知道了这个API结构的数据类型是如何设置的。直接看代码:
@property def dtype(self) -> torch.dtype: r""" Returns: `torch.dtype`: The torch dtype on which the pipeline is located. """ module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module)] for module in modules: return module.dtype return torch.float32
这里使用了一个@property的语法糖,使用该语法糖包裹住的方法能直接作为类的属性。即调用pipeline.type就会运行该函数,获得返回值。
代码也是一目了然的,首先从自己注册了的参数列表里面去找模块,一旦发现torch.nn.Module类型的立刻返回,否则则直接返回默认值torch.float32。关于这个代码有两个相关需要注意的地方要说
1. dtype如何初始化
从dtype的方法体可以看出,dtype在初始化pipeline的时候不需要特别去设置,只要pipeline包含module,就一定会返回一个数据类型。而module的初始数据类型则往往是from_pretrained方法来定义的:即pipeline.from_pretrained(torch_dtype=torch.float16/float32)。
2. 人工设置参数
经常玩pipeline的朋友可能会手动设置pipeline中的模块,即在pipeline.from_pretrained的时候对模块进行手动赋值,但是这种做法就带来了一个问题,有可能导致pipeline中的数据类型不统一,因此我觉得一个比较好的做法是在参数列表中显式指定torch_dtype或者在初始化结束之后添加to(dtype)。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· winform 绘制太阳,地球,月球 运作规律
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)