huggingface 下载与训练模型时会报 SSLERROR 连接 错误,解决办法如下

我用方案一解决

解决方案

方案1:使用代理(需要梯子)

在你的 Python 代码的开头加上如下代码

import os
os.environ['HTTP_PROXY'] = 'http://proxy_ip_address:port'
os.environ['HTTPS_PROXY'] = 'http://proxy_ip_address:port'

其中 http://proxy_ip_address:port 中的 proxy_ip_address 和 port为开启梯子后

(windows)设置>网络和Internet>代理>手动设置代理>编辑代理服务器

中的代理IP地址和端口

代理IP地址:端口

例如在我的情况下就是

import os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

 

方案2:本地下载

进入 huggingface 找到自己想要的预训练模型,以 resnet34 为例,下面是 resnet34 在 huggingface 的仓库

timm/resnet34.a1_in1k

从仓库中下载 model.safetensors 或者 pytorch_model.bin 文件

import timm
model = timm.create_model(
  'resnet34',
  pretrained=True,
  pretrained_cfg_overlay=dict(file=r'path\to\checkpoint'),
)

在调用 timm.create_model 时传入 pretrained_cfg_overlay 参数

其中 checkpoint 可以是 *.safetensors*.bin*.pth*.pt*.ckpt 等格式的存储模型权重的文件。

在传入 pretrained_cfg_overlay=dict(file=r'path\to\checkpoint') 参数后,默认的 pretrained_cfg 预训练 config 中会添加 file=r'path\to\checkpoint 键值对,导入模型权重时,代码会优先检查 config 中是否有 file 关键词,代码会优先从 file 中导入模型权重。

参见 Github 中 timm 源码:

load_from == 'file' | timm source code

关于导入config.json

从huggingface仓库中下载模型对应的config.json文件,按照下列方式传入参数

import json
import timm

path2cfg = r'path\to\config.json'
path2mdl = r'path\to\model.safetensors'
with open(path2cfg, "r", encoding="utf-8") as reader:
    text = reader.read()
    cfg_dict = json.loads(text)

model = timm.create_model(
  'resnet34.a1_in1k',
  pretrained=True,
  pretrained_cfg=cfg_dict['pretrained_cfg'],
  pretrained_cfg_overlay=dict(file=path2mdl),
)
posted @ 2024-05-29 14:06  gds111789  阅读(156)  评论(0编辑  收藏  举报