Exploring the Prefix Tuning Code
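Prefix tuning keeps the language model frozen and instead trains a small set of continuous key/value vectors that are prepended to every attention layer. The PrefixEncoder module below (in the style of the P-Tuning v2 / ChatGLM implementations) is the component that produces those vectors.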
prefix_tuning.py
import torch
from transformers import PretrainedConfig


class PrefixEncoder(torch.nn.Module):
    r'''
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    '''
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.prefix_length, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.encoder_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.encoder_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
            )
        else:
            # Learn the flattened key/value prefix directly as one embedding table
            self.embedding = torch.nn.Embedding(config.prefix_length, config.num_hidden_layers * 2 * config.hidden_size)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)       # (batch, prefix_len, hidden)
            past_key_values = self.trans(prefix_tokens)  # (batch, prefix_len, 2*layers*hidden)
        else:
            past_key_values = self.embedding(prefix)     # (batch, prefix_len, 2*layers*hidden)
        return past_key_values
if __name__ == "__main__":
    configs = {"prefix_length": 20,
               "hidden_size": 768,
               "encoder_hidden_size": 768,
               "num_hidden_layers": 12,
               "prefix_projection": False
               }
    prefix_encoder = PrefixEncoder(config=PretrainedConfig.from_dict(configs))
    print(prefix_encoder)

    batch_size = 8
    prefix = torch.arange(20).long().expand(batch_size, -1)
    print(prefix.shape)
    output = prefix_encoder(prefix)
    print(output.shape)
Output:

PrefixEncoder(
  (embedding): Embedding(20, 18432)
)
torch.Size([8, 20])
torch.Size([8, 20, 18432])
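The trailing dimension 18432 is 2 × 12 × 768, i.e. one key and one value vector for each of the 12 layers, flattened together. PrefixEncoder stops there; the host model is responsible for reshaping this tensor into per-layer past_key_values. A minimal sketch of that reshape, assuming a BERT-like backbone with 12 attention heads of size 64 (so 12 × 64 = hidden_size = 768), in the style of the P-Tuning v2 code:

num_heads, head_dim = 12, 64  # assumed head layout: 12 * 64 == hidden_size
pkv = output.view(batch_size, 20, 12 * 2, num_heads, head_dim)  # (B, L, 2*layers, heads, head_dim)
pkv = pkv.permute(2, 0, 3, 1, 4)  # (2*layers, B, heads, L, head_dim)
past_key_values = pkv.split(2)    # 12 chunks, one (key, value) pair per layer
print(len(past_key_values), past_key_values[0].shape)
# 12 torch.Size([2, 8, 12, 20, 64])

Each (2, batch, heads, prefix_len, head_dim) chunk is then split into the key and value tensors that get concatenated in front of the corresponding layer's attention keys and values.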
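Setting prefix_projection=True switches to the MLP reparameterization from the original Prefix-Tuning paper (Li & Liang, 2021): the prefix is embedded at hidden_size and projected up through a two-layer MLP, which tends to make optimization more stable than learning the huge embedding table directly. A quick check, reusing the configs dict from the demo above (a sketch continuing the same session, so prefix and PretrainedConfig are already in scope):

configs_mlp = dict(configs, prefix_projection=True)
mlp_encoder = PrefixEncoder(config=PretrainedConfig.from_dict(configs_mlp))
print(mlp_encoder)
# PrefixEncoder(
#   (embedding): Embedding(20, 768)
#   (trans): Sequential(
#     (0): Linear(in_features=768, out_features=768, bias=True)
#     (1): Tanh()
#     (2): Linear(in_features=768, out_features=18432, bias=True)
#   )
# )
print(mlp_encoder(prefix).shape)  # torch.Size([8, 20, 18432])

The output shape is unchanged, so the two branches are interchangeable from the host model's point of view; after training, the MLP can even be dropped by caching its output once.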