【代码】LLaVA 代码运行记录
LLaVA代码地址:https://github.com/haotian-liu/LLaVA
LLaVolta代码地址:https://github.com/Beckschen/LLaVolta
最近在做 LLaVA 和 LLaVolta 的学习与改进,开个贴记录一下遇到的问题以及解决方案。
LLaVA
安装 flash-attn
通过nvcc -V
可以查看版本,需要 >=11.7,否则会报错如下:
解决方案:在终端运行如下两条指令重新安装
export PATH=/usr/local/cuda/bin:$PATH export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
如果还是不行可以直接去github仓库下载对应版本whl文件:https://github.com/Dao-AILab/flash-attention/releases,其中对应:cuda版本、torch版本、python(cp)版本。下载好之后直接pip install xxx.whl即可。
Deepspeed 使用
由于服务器的原因,我的 conda 并不是在默认的 ~/ 目录中,因此每次使用 pip
和 deepspeed
需要额外指定,每个.sh
文件需要明确 deepspeed 路径,否则找不到transformers
包。
解决方案:更改使用deepspeed的命令
/home/data/shika/miniconda3/envs/llava/bin/deepspeed llava/train/train_mem.py
wandb使用
wandb库需要申请 api 并登录,按照这个教程即可成功登录。
多卡训练端口
如果多卡训练,需要制定一个端口,刚好默认的29500被占用了:
解决方案:手动重新指定端口
在.sh脚本中加入参数--master_port <port>
,注意一定要在train_mem.py之前,否则无法识别就会报错。
LLaVA-Bench-Wild 数据集评估
直接使用会报错
Error communicating with OpenAI: ('Connection aborted.', ConnectionResetError(104, 'Connection reset by peer'))
解决方案:手动降级 urllib3
pip install urllib3==1.26.11
LLaVolta
LLaVolta 基本是基于 LLaVA 进行改造的,因此运行方式高度相似(基本相同)。
包的版本
按照README.md
会安装 peft==0.14.0
和transfromers==4.37.2
,然后这个错误困扰了我很长时间:
- peft==0.14.0 + transformers 最新版本 --> TypeError: LlamaRotaryEmbedding.forward() got an unexpected keyword argument 'seq_len'
- peft==0.14.0 + transformers==4.37.2 --> cannot import name 'EncoderDecodercache' from 'transformers'
解决方案:必须用 peft==0.10.0 + transformers==4.37.2
本文作者:KeanShi
本文链接:https://www.cnblogs.com/keanshi/p/18602710
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步