prompt gating代码探索

import torch

def promptGating(gating, adding, x):
    '''
    gating: (num_prefix, dim)  
    adding: (num_prefix, dim) 
    x: (seq_length, batch_size, dim) 
    '''
    if gating is not None:
        gating = gating.unsqueeze(0).expand(x.size(1), -1, -1).transpose(0, 1) # (num_prefix,batch_size,dim)
        gating = torch.cat([gating, torch.ones([x.size(0)-gating.size(0), x.size(1), x.size(2)])], axis=0) 
        # (seq_length, batch_size, dim)
        x = x * gating # prefix之外*1

        if adding is not None: #相当于加上bias
            adding = adding.unsqueeze(0).expand(x.size(1), -1, -1).transpose(0, 1) 
            adding = torch.cat([adding, torch.zeros([x.size(0)-adding.size(0), x.size(1), x.size(2)])], axis=0)

            x = adding + x  # prefix之外+0
    return x

if __name__ == "__main__":
    num_prompt, batch_size, seq_length, dim = 2, 8, 22, 1024
    gating = torch.randn(num_prompt, dim) 
    adding = torch.randn(num_prompt, dim) 
    x = torch.randn(seq_length, batch_size, dim) 

    new_x = promptGating(gating, adding, x)
    print(new_x.shape)
    # 输出:torch.Size([22, 8, 1024])

posted @   鸽鸽的书房  阅读(25)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
点击右上角即可分享
微信分享提示