如何查找Model的state_dict和ckpt的state_dict之间的差距
参考资料:
[自己摸索]
[chatgpt3.5]
众所周知,Huggingface团队的transformers库是一个非常优秀非常方便的库,它使得很多模型实现了“开箱即用”。但是,由于transformers这个库的快速迭代,也导致了很多兼容性上的问题。比如今天我发现一个现象:我使用老板的transformers库(4.19)去加载FrozenCLIPEmbedder这个组件,并且加载了SD1.5的ckpt中的权重,是没有任何问题的。但是当我使用更新的版本的transformers库(4.34)去加载FrozenCLIPEmbedder这个组件,并且想加载SD1.5的ckpt中的权重时,却会发现一个报错,如下所示:
报错也是非常容易理解,就是ckpt文件中的state_dict出现了当前版本模型没有的key_word:position_ids。分析了一下我认为有可能是position_ids这个东西可能是多余的,因此在后面版本的transformers库中官方就把这个东西给丢掉了。为了验证这个猜想,我读了一下ckpt中的对应值,果然是一个常量(就是0~76嘛!):
有经验的读者肯定已经想到解决办法了,在load_state_dict的时候加上一个option不就行了么?strict=False,这样model在加载state_dict的时候就不会强制要求一一对应了。这样做当然是没有错的,但是这次可以这样绕过,下一次如果再次出现大面积的key不匹配该怎么办?有没有一个好的工具性的代码片能够帮助我们检视?这样我们就可以做一些骚操作比如对state_dict中内的字典关键字重命名从而使得权重互相匹配上。
下面请出我们的chatgpt3.5老师,这是它写出的代码片段:
import torch def compare_state_dicts(ckpt_path, model_state_dict): ckpt_state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict'] print("Keys present only in ckpt state_dict:") for key in ckpt_state_dict.keys(): if key not in model_state_dict: print(key) print("\nKeys present only in model state_dict:") for key in model_state_dict.keys(): if key not in ckpt_state_dict: print(key) print("\nKeys with different values:") for key in model_state_dict.keys(): if key in ckpt_state_dict and not torch.equal(model_state_dict[key], ckpt_state_dict[key]): print(key) # Example usage # Load model state_dict model = YourModel() model_state_dict = model.state_dict() # Provide ckpt path ckpt_path = 'path_to_ckpt.pth' # Compare state_dicts compare_state_dicts(ckpt_path, model_state_dict)
Keys with different values这个判断显然是不需要的,其他地方写的倒是没有什么问题。我们运行一下试试看~
事实证明从ckpt中加载text_encoder这一步完全是脱裤子放屁。因为在初始化FrozenCLIPEmbedder的时候已经加载了clip-vit-large-patch14中的权重,在训练SD1.5的时候又完全没有动这一部分的权重,那么除了position_ids这个参数关键字被删去之外,就完全没有差异了。