如何查找Model的state_dict和ckpt的state_dict之间的差距

  参考资料:

  [自己摸索]

  [chatgpt3.5]

  众所周知,Huggingface团队的transformers库是一个非常优秀非常方便的库,它使得很多模型实现了“开箱即用”。但是,由于transformers这个库的快速迭代,也导致了很多兼容性上的问题。比如今天我发现一个现象:我使用老板的transformers库(4.19)去加载FrozenCLIPEmbedder这个组件,并且加载了SD1.5的ckpt中的权重,是没有任何问题的。但是当我使用更新的版本的transformers库(4.34)去加载FrozenCLIPEmbedder这个组件,并且想加载SD1.5的ckpt中的权重时,却会发现一个报错,如下所示:

  报错也是非常容易理解,就是ckpt文件中的state_dict出现了当前版本模型没有的key_word:position_ids。分析了一下我认为有可能是position_ids这个东西可能是多余的,因此在后面版本的transformers库中官方就把这个东西给丢掉了。为了验证这个猜想,我读了一下ckpt中的对应值,果然是一个常量(就是0~76嘛!):

  有经验的读者肯定已经想到解决办法了,在load_state_dict的时候加上一个option不就行了么?strict=False,这样model在加载state_dict的时候就不会强制要求一一对应了。这样做当然是没有错的,但是这次可以这样绕过,下一次如果再次出现大面积的key不匹配该怎么办?有没有一个好的工具性的代码片能够帮助我们检视?这样我们就可以做一些骚操作比如对state_dict中内的字典关键字重命名从而使得权重互相匹配上。

  下面请出我们的chatgpt3.5老师,这是它写出的代码片段:

import torch

def compare_state_dicts(ckpt_path, model_state_dict):
    ckpt_state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
    
    print("Keys present only in ckpt state_dict:")
    for key in ckpt_state_dict.keys():
        if key not in model_state_dict:
            print(key)
    
    print("\nKeys present only in model state_dict:")
    for key in model_state_dict.keys():
        if key not in ckpt_state_dict:
            print(key)
    
    print("\nKeys with different values:")
    for key in model_state_dict.keys():
        if key in ckpt_state_dict and not torch.equal(model_state_dict[key], ckpt_state_dict[key]):
            print(key)

# Example usage

# Load model state_dict
model = YourModel()
model_state_dict = model.state_dict()

# Provide ckpt path
ckpt_path = 'path_to_ckpt.pth'

# Compare state_dicts
compare_state_dicts(ckpt_path, model_state_dict)

  Keys with different values这个判断显然是不需要的,其他地方写的倒是没有什么问题。我们运行一下试试看~

  事实证明从ckpt中加载text_encoder这一步完全是脱裤子放屁。因为在初始化FrozenCLIPEmbedder的时候已经加载了clip-vit-large-patch14中的权重,在训练SD1.5的时候又完全没有动这一部分的权重,那么除了position_ids这个参数关键字被删去之外,就完全没有差异了。

posted @ 2023-10-02 15:51  思念殇千寻  阅读(69)  评论(0编辑  收藏  举报