import torch
def promptGating(gating, adding, x):
'''
gating: (num_prefix, dim)
adding: (num_prefix, dim)
x: (seq_length, batch_size, dim)
'''
if gating is not None:
gating = gating.unsqueeze(0).expand(x.size(1), -1, -1).transpose(0, 1)
gating = torch.cat([gating, torch.ones([x.size(0)-gating.size(0), x.size(1), x.size(2)])], axis=0)
x = x * gating
if adding is not None:
adding = adding.unsqueeze(0).expand(x.size(1), -1, -1).transpose(0, 1)
adding = torch.cat([adding, torch.zeros([x.size(0)-adding.size(0), x.size(1), x.size(2)])], axis=0)
x = adding + x
return x
if __name__ == "__main__":
num_prompt, batch_size, seq_length, dim = 2, 8, 22, 1024
gating = torch.randn(num_prompt, dim)
adding = torch.randn(num_prompt, dim)
x = torch.randn(seq_length, batch_size, dim)
new_x = promptGating(gating, adding, x)
print(new_x.shape)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人