如何加载本地下载下来的BERT模型,pytorch踩坑!!

近期做实验频繁用到BERT,所以想着下载下来使用,结果各种问题,网上一搜也是简单一句:xxx.from_pretrained("改为自己的路径")
我只想说,大坑!!!
废话不多说:

1.下载模型文件:

不管你是从hugging-face还是哪里下载来的模型(pytorch版)文件夹,应该包含以下三个文件:

  • config.json
  • vocab.txt
  • pytorch_model.bin

具体都是什么内容,不做介绍,你也不需要知道

2.更改文件名!!(坑点1)

很多下载的模型文件夹里面上述三个文件名字可能会有不同,一定要注意!以清华OpenCLaP上下载下来的民事BERT为例,其中包含了三个文件对应的名字为:

  • bert_config.json 看到没有!!这个前面多了个bert_,一定要改掉!bert_config.json
  • vocab.txt
  • pytorch_model.bin

三个文件一定要与第一步中的结构一样,名字也必须一样

3.将文件放入自己的文件夹

这里我们在自己的工程目录里新建一个文件夹:bert_localpath,将三个文件放入其中,最终结构如下:

bert_localpath

config.json
vocab.txt
pytorch_model.bin

4.加载(坑点2)

使用 .from_pretrained("xxxxx")方法加载,本地加载bert需要修改两个地方,一是tokenizer部分,二是model部分:
step1、导包: from transformers import BertModel,BertTokenizer
step2、载入词表: tokenizer = BertTokenizer.from_pretrained("./bert_localpath/") 这里要注意!!除了你自己建的文件夹名外,后面一定要加个/,才能保证该方法找到你的vocab.txt
step3、载入模型: bert = BertModel.from_pretrained("./bert_localpath") 然后,这个地方又不需要加上/

5.使用

至此,你就能够使用你的本地bert了!!例如~outputs = bert(input_ids, token_type_ids, attention_mask)来获得token的编码输出output

over,网上很多教程对小白很不友好,记录一下自己的踩坑,希望能帮到你,如果觉得我写的有问题的或者太简单的,可以去看看其他人的

posted @ 2022-01-28 12:03  ZhangHT97  阅读(17676)  评论(4编辑  收藏  举报