python系列&deep_study系列:模型下载的几种方式
模型下载的几种方式
问题描述
作为一名自然语言处理算法人员,hugging face开源的transformers包在日常的使用十分频繁。在使用过程中,每次使用新模型的时候都需要进行下载。如果训练用的服务器有网,那么可以通过调用from_pretrained方法直接下载模型。但是就本人的体验来看,这种方式尽管方便,但还是会有两方面的问题:
-
如果网络很不好,模型下载时间会很久,一个小模型下载几个小时也很常见
-
如果换了训练服务器,又要重新下载。
这里可能大家会疑惑,为什么不能把当前下载好的模型迁移过去,我们可以看下通过from_pretrained保存的文件(一般在~/.cache/huggingface/transformers文件夹下)模型文件
!https://s3-us-west-2.amazonaws.com/secure.notion-static.com/79042590-35ff-4181-9c70-1db5bf713183/v2-6a9100687e302faffa91950ac21102f1_720w.jpg
推荐方式
transformers下载 推荐
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
model_name = "openai/whisper-large-v2"
processor = AutoProcessor.from_pretrained(model_name )
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name )
Hugging Face Hub 下载 推荐
pip install huggingface_hub
from huggingface_hub import snapshot_download
snapshot_download(repo_id="bert-base-chinese")
# allow_regex和ignore_regex两个参数,简单来说前者是对指定的匹配项进行下载,后者是忽略指定的匹配项,下载其余部分
snapshot_download(repo_id="bert-base-chinese", ignore_regex=["*.h5", "*.ot", "*.msgpack"])
requests 下载
import os
import json
import requests
from uuid import uuid4
from tqdm import tqdm
SESSIONID = uuid4().hex
VOCAB_FILE = "vocab.txt"
CONFIG_FILE = "config.json"
MODEL_FILE = "pytorch_model.bin"
BASE_URL = "https://huggingface.co/{}/resolve/main/{}"
headers = {'user-agent': 'transformers/4.8.2; python/3.8.5; \
session_id/{}; torch/1.9.0; tensorflow/2.5.0; \
file_type/model; framework/pytorch; from_auto_class/False'.format(SESSIONID)}
model_id = "bert-base-chinese"
# 创建模型对应的文件夹
model_dir = model_id.replace("/", "-")
if not os.path.exists(model_dir):
os.mkdir(model_dir)
# vocab 和 config 文件可以直接下载
r = requests.get(BASE_URL.format(model_id, VOCAB_FILE), headers=headers)
r.encoding = "utf-8"
with open(os.path.join(model_dir, VOCAB_FILE), "w", encoding="utf-8") as f:
f.write(r.text)
print("{}词典文件下载完毕!".format(model_id))
r = requests.get(BASE_URL.format(model_id, CONFIG_FILE), headers=headers)
r.encoding = "utf-8"
with open(os.path.join(model_dir, CONFIG_FILE), "w", encoding="utf-8") as f:
json.dump(r.json(), f, indent="\t")
print("{}配置文件下载完毕!".format(model_id))
# 模型文件需要分两步进行
# Step1 获取模型下载的真实地址
r = requests.head(BASE_URL.format(model_id, MODEL_FILE), headers=headers)
r.raise_for_status()
if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"]
# Step2 请求真实地址下载模型
r = requests.get(url_to_download, stream=True, proxies=None, headers=None)
r.raise_for_status()
# 这里的进度条是可选项,直接使用了transformers包中的代码
content_length = r.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(
unit="B",
unit_scale=True,
total=total,
initial=0,
desc="Downloading Model",
)
with open(os.path.join(model_dir, MODEL_FILE), "wb") as temp_file:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
print("{}模型文件下载完毕!".format(model_id))
Git LFS 下载
准备工作
Git LFS
的方案相较于前面自行实现的方案要简洁的多得多。我们需要在安装git
的基础上,再安装git lfs
。以Windows
为例,命令如下
git lfs install
模型下载
我们还是以bert-base-chinese
为例进行下载,打开具体的模型面,可以看到右上角有一个Use in Transformers
的button
。
点击该Button
,我们就可以看到具体的下载命令了。
拷贝命令在终端执行,就可以下载了。下载后的格式,和前面自行实现的代码是一样,但是就使用体验上来看,这种方式明显会更加优雅!
但是,这种方案也存在着一定的问题,即会下载仓库中的所有文件,会大大延长模型下载的时间。我们可以看到在目录中包含着flax_model.msgpack
、tf_model.h5
和pytorch_model.bin
三个不同框架模型文件,在bert-base-uncased
的版本中,还存在着rust版本
的rust_model.ot模型
,如果我们只想要一个版本的模型文件,这种方案就无法实现了。
如果想实现模型精确下载,我们还可以借助Hugging Face Hub
,下面来介绍这种方案。