【代码】LLaVA 代码学习与调试

代码地址:https://github.com/haotian-liu/LLaVA
感谢 b 站 @良睦路程序员 的教程

LLaVA 是一个端到端训练的大型多模态模型,旨在根据视觉输入(图像)和文本指令理解和生成内容。已经成为目前最基本、影响力最广的多模态大语言模型之一;恰好最近正在学习 LLaVA 整体结构,特此开帖记录在调试中学习代码的过程。

Prepare

本人使用 VsCode + macOS 进行代码调试,可以直接使用官方给出的 demo;

from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

prompt = ""
url = ""
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=image, return_tensors="pt")

# Generate
generate_ids = model.generate(**inputs, max_length=600)
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

print(response)

注意,由于我们真正调试的是LlavaForConditionalGenerationforward 函数,所以要在launch.json中手动添加一行:

"justMyCode": false

最后在forward函数第一行打上断点即可开始调试。

Code review

注:所有记录均基于pretrain.sh的默认参数,如果改过可能会有不同
不会 LLaVA 的同学可以点击这里学习

image

纯文本

用户输入一句指令后,会有指定的 tokenizer 对其进行切分成 input_ids,LLaVA 使用的是 transformers 库中的 AutoTokenizer,在 train.py 中可以找到:

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    model_max_length=training_args.model_max_length,
    padding_side="right",
    use_fast=False,
)

LLaVA使用的LLaVATrainer继承自Trainer,声明如下,可以看到传入了tokenizer

trainer = LLaVATrainer(model=model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module)

接下来是模型对于 input_ids 的处理,也就是 LlavaLlamaForCausalLM 类,在
llava/model/language_model/llava_llama.py 中定义如下:

class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig
    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

其继承自 LlamaForCausalLMLlavaMetaForCausalLM,这里我们想看 tokens 处理,因此找 LLama 相关的即可。点击跳转后可以发现其定义如下:

class LlamaForCausalLM(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

发现这里使用的为LlamaModel,继续点击跳转:

class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """
    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._use_sdpa = config._attn_implementation == "sdpa"
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

其中 text embedding 对应的为第14行的定义,即 LLaVA 使用的 nn.Embedding 进行文本嵌入,在 forward 方法中,inputs_embeds 即为词嵌入的 tokens:

if inputs_embeds is None:
    inputs_embeds = self.embed_tokens(input_ids)

注:后续 hidden_states 会初始化为 inputs_embeds ,经过多层 transformer,自回归式地进行训练。

多模态处理

该部分是LLaVA的重点创新部分,整体逻辑为将 image token 和 text token 拼接起来一起送入LLM中。
首先看LLaVA的核心模型类 LlavaLlamaForCausalLM,由于是多模态处理,需要关注继承的 LlavaMetaForCausalLM 类:

class LlavaMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

其中包含了对于 vision 模态的处理,而其精髓部分在于 prepare_inputs_labels_for_multimodal(也是整个 LLaVA 的核心)。我们重点关注最终返回的 new_input_embeds,跟踪其生成过程:

  1. 处理没有图像的单模态数据时:
if num_images == 0:
    cur_image_features = image_features[cur_image_idx]
    cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
    cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
    new_input_embeds.append(cur_input_embeds)
    new_labels.append(labels[batch_idx])
    cur_image_idx += 1
    continue
  • cur_input_embeds_1 是通过模型的 embed_tokens 方法将 cur_input_ids 转换为文本嵌入。
  • cur_image_features[0:0] 是一个空的张量,用于占位(因为没有图像)。
  • 将文本嵌入和空的图像嵌入拼接在一起,得到 cur_input_embeds。
  • 将 cur_input_embeds 添加到 new_input_embeds 中,并将对应的标签添加到 new_labels 中。
  1. 处理有图像的多模态数据时:
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
    cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
    cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
  • 如果当前 batch 中有图像 token,首先找到所有图像 token 的位置(image_token_indices)。
  • 将 cur_input_ids 和 cur_labels 按照图像 token 的位置分割成多个部分,分别存储到 cur_input_ids_noim 和 cur_labels_noim 中。
  1. 文本嵌入,继续沿用之前写的embed_tokens层:
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
  • 将分割后的 cur_input_ids_noim 拼接在一起,并通过 embed_tokens 方法生成文本嵌入 cur_input_embeds。
  • 将 cur_input_embeds 按照 split_sizes 重新分割,得到 cur_input_embeds_no_im,即每个部分的文本嵌入。
  1. 图像的features生成

这部分 LLaVA 分为了多张图和单张图的 encode,多张图需要 concat 在一起,会考虑不同图片的空间、位置信息,encode 完还需要 split 到原始位置,但核心为使用 self.encode_images 方法进行编码:

image_features = self.encode_images(images)

该过程定义如下,使用的是训练好的ViT,也没有进行BP参数更新:

def encode_images(self, images):
    image_features = self.get_model().get_vision_tower()(images)
    image_features = self.get_model().mm_projector(image_features)
    return image_features
  1. 整合图像和文本token
cur_new_input_embeds = []
cur_new_labels = []

for i in range(num_images + 1):
    cur_new_input_embeds.append(cur_input_embeds_no_im[i])
    cur_new_labels.append(cur_labels_noim[i])
    if i < num_images:
        cur_image_features = image_features[cur_image_idx]
        cur_image_idx += 1
        cur_new_input_embeds.append(cur_image_features)
        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
        
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)

new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
  • 将刚刚分离的文本和图像对应的token重新整合在一起,即 cur_new_input_embeds。
  • 将 cur_new_input_embeds 中的文本嵌入和图像嵌入拼接在一起,得到当前 batch 的最终嵌入。
  • 注:最后还需要维度扩展到相同,即 \(batch\_size \times seq\_len \times embed\_len\)
posted @ 2024-12-20 14:48  KeanShi  阅读(4)  评论(0编辑  收藏  举报