用断点调试阅读peft源码:prefix tuning
今天我们阅读peft源码,主要是为了弄清楚prefix tuning的工作原理和代码细节。
理解和思考
(1) prefix tuning和zero-shot的区别在于,把指令/要求(比如要生成positive的句子)和输入的文字直接区分开,指令用连续向量而不是离散词元表示。如果不是prefix tuning,那么需要用明确的语言做prompt engineering,比如:
要求:生成积极的句子。
主题:运动
例子:游泳有益于身心健康
开始:
提示工程不需要训练,但这样做文字游戏很难调出来。这就凸显了prefix tuning的好处,直接用连续向量代替这些指令,并编码指令与输入的关系、以及对输出的指导作用。又或者,指令已经在输入中写得很明确,但prefix tuning在zero-shot的基础上进一步强化了这些指令对输出的指导作用。
感觉如果把prefix写成文字,加到输入前面,然后只训练prefix的参数,好像也能达到同样的效果。但是prefix的长度就不是固定了!由此可见,越是复杂的任务越需要更长的prefix~
(2) prefix tuning和finetuning的区别在于只训练prefix部分的参数~
模型定义部分
peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, num_virtual_tokens=20)
# 下载预训练模型T5,模型结构可以在debug console中输入model得到
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
主要是这一句:model = get_peft_model(model, peft_config)
,所以在这里设置断点。
首先跳转到:peft->mapping.py->get_peft_model函数。我逐行阅读并做出中文注释。
def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
"""
Returns a Peft model object from a model and a config.
Args:
model ([`transformers.PreTrainedModel`]): Model to be wrapped.
peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
"""
model_config = getattr(model, "config", {"model_type": "custom"}) # 得到T5模型config,在debug console中输入model_config可以查看
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict() #把config中的属性序列化为 Python 字典
peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
# <TaskType.SEQ_2_SEQ_LM: 'SEQ_2_SEQ_LM'>
# dict_keys(['SEQ_CLS', 'SEQ_2_SEQ_LM', 'CAUSAL_LM', 'TOKEN_CLS', 'QUESTION_ANS', 'FEATURE_EXTRACTION'])
if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(
peft_config, PromptLearningConfig
):
return PeftModel(model, peft_config, adapter_name=adapter_name)
if isinstance(peft_config, PromptLearningConfig):
peft_config = _prepare_prompt_learning_config(peft_config, model_config)
return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
我们从最后一句跳进去,来到了peft->peft_model.py->PeftModelForSeq2SeqLM(PeftModel)类,
在mapping.py我们看到:
MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
"SEQ_CLS": PeftModelForSequenceClassification,
"SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
"CAUSAL_LM": PeftModelForCausalLM,
"TOKEN_CLS": PeftModelForTokenClassification,
"QUESTION_ANS": PeftModelForQuestionAnswering,
"FEATURE_EXTRACTION": PeftModelForFeatureExtraction,
}
所以MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type]
定义了我们的模型是PeftModelForSeq2SeqLM,而传入的参数是model, peft_config, adapter_name=adapter_name.
prefix tuning
找半天没看到prefix tuning的代码,我们直接打开/root/miniconda3/envs/peft-practice/lib/python3.10/site-packages/peft/tuners/prefix_tuning.py查看,发现它改写自Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py
class PrefixEncoder(torch.nn.Module):
r'''
The torch.nn model to encode the prefix
Input shape: (batch-size, prefix-length)
prefix-length/num_virtual_tokens:20, hidden_size:768, prefix_hidden_size
Output shape: (batch-size, prefix-length, 2*layers*hidden)
'''
def __init__(self, config):
super().__init__()
self.prefix_projection = config.prefix_projection
if self.prefix_projection:
# Use a two-layer MLP to encode the prefix
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
self.trans = torch.nn.Sequential(
torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
)
else:
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)
def forward(self, prefix: torch.Tensor):
if self.prefix_projection:
prefix_tokens = self.embedding(prefix)
past_key_values = self.trans(prefix_tokens)
else:
past_key_values = self.embedding(prefix)
return past_key_values
不考虑Use a two-layer MLP to encode the prefix的话,prefix tuning主要包括以下代码:
class PrefixEncoder(torch.nn.Module):
def __init__(self, config):
super().__init__()
...
self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim) #num_virtual_tokens=20,token_dim=1024,num_layers=24
def forward(self, prefix: torch.Tensor):
past_key_values = self.embedding(prefix)
return past_key_values
得到的PrefixEncoder被传入peft->peft_model.py->prompt_encoder:
PrefixEncoder(
(embedding): Embedding(20, 49152) # 1024*2*24
)
self.prompt_tokens初始化为长度2*20的向量,因为T5有编码器和解码器,需要两次prefix:
self.prompt_tokens[adapter_name] = torch.arange(
config.num_virtual_tokens * config.num_transformer_submodules
).long() #20*2
# tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
# 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
# 36, 37, 38, 39])
在训练模式下,prompt_tokens复制成batch size个向量,作为prompt_encoder的输入,输出embedding:
prompt_tokens = (
self.prompt_tokens[self.active_adapter]
.unsqueeze(0)
.expand(batch_size, -1)
.to(prompt_encoder.embedding.weight.device)
)
prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]
# 此时prompt_tokens.shape = (batch_size=8, num_virtual_tokens=20)
past_key_values = prompt_encoder(prompt_tokens)
torch.Size([8, 20, 49152])
这里的past_key_values来源于:https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None, #
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
但目前的past_key_values还是所有层的集合,我们需要把past_key_values分解为每一层:
past_key_values = past_key_values.view(
batch_size, #8
peft_config.num_virtual_tokens, #20
peft_config.num_layers * 2, #24*2
peft_config.num_attention_heads, #16
peft_config.token_dim // peft_config.num_attention_heads, #1024/16
)
# torch.Size([8, 20, 48, 16, 64])
因为有编码器和解码器,所以再复制一次:
past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
# torch.Size([8, 20, 96, 16, 64])
# 重排:torch.Size([96, 8, 16, 20, 64])
# 然后split成一个长度为24的tuple,每个tuple的shape:torch.Size([4, 8, 16, 20, 64])
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
peft_config.num_transformer_submodules * 2
)
也就是说past_key_values是24个层的Prefix embedding,形状为`(num_transformer_submodules * 2, batch_size, num_attention_heads, num_virtual_tokens, token_dim/num_attention_heads])
注意这里*2是因为key+value.
这些参数被传入T5:https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
# input_ids.shape: torch.Size([8, 128])
self.base_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
decoder_inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
**kwargs,
)
我们来近距离看:transformers->models->t5->modeling_t5.py->T5Attention类,这里的关键步骤是project函数中的hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
,注意project函数仅仅用于key和value。
def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
output_attentions=False,
):
"""
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length
if past_key_value is not None:
if len(past_key_value) != 2:
raise ValueError(
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
)
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def unshape(states):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
# 注意这里是重点:用串联方式
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the ` sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else:
# cross-attn
hidden_states = past_key_value
return hidden_states
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
分别计算query_states、key_states、value_states,用query和key计算attention score,得到score形状为torch.Size([8, 16, 2, 22]),所以输入X可以attend to itself以及prefix。
# hidden_states shape: torch.Size([8, 2, 1024])
# get query states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
# query_states shape: torch.Size([8, 16, 2, 64])
# get key/value states
key_states = project(
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
# key_states shape: torch.Size([8, 16, 22, 64])
value_states = project(
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
# value_states shape: torch.Size([8, 16, 22, 64])
# compute scores
# torch.Size([8, 16, 2, 22])
scores = torch.matmul(
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
接下来就是经典的attention操作了。用attn_weights ([8, 16, 2, 22]) 和value_states ([8, 16, 22, 64])相乘,把22消掉,就是每个输入X的输出了。
# if key and values are already calculated
# we want only the last query position bias
# position_bias.shape: torch.Size([8, 16, 2, 22])
scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores
) # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) torch.Size([8, 2, 1024])
attn_output = self.o(attn_output)
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions:
outputs = outputs + (attn_weights,)
return outputs
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)