【杂学】LLaVA 代码运行记录

LLaVA代码地址:https://github.com/haotian-liu/LLaVA
LLaVolta代码地址:https://github.com/Beckschen/LLaVolta

最近在做 LLaVA 和 LLaVolta 的学习与改进,开个贴记录一下遇到的问题以及解决方案。

LLaVA

安装 flash-attn

通过nvcc -V可以查看版本,需要 >=11.7,否则会报错如下:

image

解决方案:在终端运行如下两条指令重新安装

export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH

Deepspeed 使用

由于服务器的原因,我的 conda 并不是在默认的 ~/ 目录中,因此每次使用 pipdeepspeed 需要额外指定,每个.sh 文件需要明确 deepspeed 路径,否则找不到transformers包。

解决方案:更改使用deepspeed的命令

/home/data/shika/miniconda3/envs/llava/bin/deepspeed llava/train/train_mem.py 

wandb使用

wandb库需要申请 api 并登录,按照这个教程即可成功登录。

多卡训练端口

如果多卡训练,需要制定一个端口,刚好默认的29500被占用了:
image

解决方案:手动重新指定端口

export MASTER_PORT=29400  # 替换为你需要的端口

LLaVolta

LLaVolta 基本是基于 LLaVA 进行改造的,因此运行方式高度相似(基本相同)。

包的版本

按照README.md会安装 peft==0.14.0transfromers==4.37.2,然后这个错误困扰了我很长时间:

  1. peft==0.14.0 + transformers 最新版本 --> TypeError: LlamaRotaryEmbedding.forward() got an unexpected keyword argument 'seq_len'
  2. peft==0.14.0 + transformers==4.37.2 --> cannot import name 'EncoderDecodercache' from 'transformers'

解决方案:必须用 peft==0.10.0 + transformers==4.37.2

posted @ 2024-12-12 15:52  KeanShi  阅读(31)  评论(0编辑  收藏  举报