Position embedding code implementations (Transformer/ViT/Swin/MAE)

The code is from the bilibili video: 46、四种Position Embedding的原理与PyTorch手写逐行实现(Transformer/ViT/Swin-T/MAE)

Saving this here first; I'll go through it carefully later. Some of it is hard to follow without actually experimenting…

import torch
import torch.nn as nn

# 1. 1d absolute sincos constant embedding
# standard Transformer

def create_1d_absolute_sincos_embeddings(n_pos_vec, dim):
    # n_pos_vec: torch.arange(n_pos, dtype=torch.float)

    assert dim % 2 == 0, "wrong dimension!"
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)

    # frequencies 1/10000^(2i/dim) for i in [0, dim//2)
    omega = torch.arange(dim//2, dtype=torch.float)
    omega /= dim/2.
    omega = 1. / (10000**omega)

    # outer product: n_pos_vec as a column vector, omega as a row vector
    out = n_pos_vec[:, None] @ omega[None, :]  # [n_pos, dim//2]

    emb_sin = torch.sin(out)
    emb_cos = torch.cos(out)

    # sin at even channels, cos at odd channels
    position_embedding[:, 0::2] = emb_sin
    position_embedding[:, 1::2] = emb_cos

    return position_embedding
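
# Quick sanity check (my own sketch, not from the video): entry [pos, 2i] should
# equal sin(pos / 10000^(2i/dim)) and entry [pos, 2i+1] the matching cos.
def _check_1d_sincos(n_pos=4, dim=4):
    pe = create_1d_absolute_sincos_embeddings(torch.arange(n_pos, dtype=torch.float), dim)
    pos, i = 2, 1  # an arbitrary position/frequency pair
    assert torch.isclose(pe[pos, 2*i], torch.sin(torch.tensor(pos / 10000**(2*i/dim))))
    return pe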

# 2. 1d absolute trainable embedding
# vision transformer

def create_1d_absolute_trainable_embeddings(n_pos_vec, dim):
    # n_pos_vec: torch.arange(n_pos)

    # a lookup table of learnable vectors, one per position
    position_embedding = nn.Embedding(n_pos_vec.numel(), dim)
    nn.init.constant_(position_embedding.weight, 0.)

    return position_embedding
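
# Usage sketch (not from the video): the returned nn.Embedding is indexed with
# integer positions, and its weights are updated by the optimizer during training.
def _demo_trainable_pe(n_pos=197, dim=768):  # hypothetical ViT-Base sizes
    pe_table = create_1d_absolute_trainable_embeddings(torch.arange(n_pos), dim)
    pos_emb = pe_table(torch.arange(n_pos))  # [n_pos, dim]
    return pos_emb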

# 3. 2d relative bias trainable embedding
# swin transformer

def create_2d_relative_bias_trainable_embeddings(n_head, height, width, dim):
    # width:5, [0,1,2,3,4], bias range [-width+1, width-1], i.e. 2*width-1 values
    # height:5, [0,1,2,3,4], bias range [-height+1, height-1], i.e. 2*height-1 values
    # note: dim is unused here; the bias is one scalar per head per position pair
    position_embedding = nn.Embedding((2*width-1)*(2*height-1), n_head)
    nn.init.constant_(position_embedding.weight, 0.)

    def get_relative_position_index(height, width):
        # indexing="ij" matches the old meshgrid default and silences the newer warning
        m1, m2 = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
        coords = torch.stack([m1, m2])  # [2, height, width]
        coords_flatten = torch.flatten(coords, 1)  # [2, height*width]

        # pairwise offsets between every pair of positions
        relative_coords_bias = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, height*width, height*width]

        # shift the offsets so they start from 0
        relative_coords_bias[0, :, :] += height - 1
        relative_coords_bias[1, :, :] += width - 1

        # flatten a 2d offset into a 1d index: A:2d, B:1d, B[i*cols+j] = A[i,j]
        relative_coords_bias[0, :, :] *= relative_coords_bias[1, :, :].max() + 1

        return relative_coords_bias.sum(0)  # [height*width, height*width]

    relative_position_index = get_relative_position_index(height, width)
    bias_embedding = position_embedding(torch.flatten(relative_position_index)).reshape(height*width, height*width, n_head)  # [height*width, height*width, n_head]

    bias_embedding = bias_embedding.permute(2, 0, 1).unsqueeze(0)  # [1, n_head, height*width, height*width]

    return bias_embedding
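
# Usage sketch (not from the video): in Swin the bias is added to the pre-softmax
# attention logits of each window; the shapes below are hypothetical.
def _demo_relative_bias(batch=2, n_head=3, height=4, width=4):
    hw = height * width
    scores = torch.randn(batch, n_head, hw, hw)  # stand-in attention logits
    bias = create_2d_relative_bias_trainable_embeddings(n_head, height, width, dim=96)
    attn = torch.softmax(scores + bias, dim=-1)  # bias broadcasts over the batch
    return attn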
        
# 4. 2d absolute constant sincos embedding
# masked autoencoder

def create_2d_absolute_sincos_embeddings(height, width, dim):
    # half of dim encodes the row coordinate, the other half the column coordinate
    assert dim % 4 == 0, "wrong dimension!"

    position_embedding = torch.zeros(height*width, dim)
    m1, m2 = torch.meshgrid(torch.arange(height, dtype=torch.float), torch.arange(width, dtype=torch.float), indexing="ij")
    coords = torch.stack([m1, m2])  # [2, height, width]

    # reuse the 1d sincos embedding (not the trainable one) on each coordinate axis
    height_embedding = create_1d_absolute_sincos_embeddings(torch.flatten(coords[0]), dim//2)  # [height*width, dim//2]
    width_embedding = create_1d_absolute_sincos_embeddings(torch.flatten(coords[1]), dim//2)  # [height*width, dim//2]

    position_embedding[:, :dim//2] = height_embedding
    position_embedding[:, dim//2:] = width_embedding

    return position_embedding
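
# Usage sketch (not from the video): for an MAE/ViT-style 14x14 patch grid with
# embed dim 768, the table is simply added to the flattened patch tokens.
def _demo_2d_sincos(height=14, width=14, dim=768):
    pe = create_2d_absolute_sincos_embeddings(height, width, dim)  # [196, 768]
    patch_tokens = torch.randn(1, height*width, dim)  # stand-in patch embeddings
    return patch_tokens + pe[None, :, :]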

    
if __name__ == "__main__":
    n_pos = 4
    dim = 4
    n_pos_vec = torch.arange(n_pos,dtype=torch.float)
    pe = create_1d_absolute_sincos_embeddings(n_pos_vec,dim)
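
    # quick shape checks for all four variants (my addition)
    print(pe.shape)  # [4, 4]
    print(create_1d_absolute_trainable_embeddings(n_pos_vec, dim).weight.shape)  # [4, 4]
    print(create_2d_relative_bias_trainable_embeddings(n_head=2, height=3, width=3, dim=dim).shape)  # [1, 2, 9, 9]
    print(create_2d_absolute_sincos_embeddings(height=3, width=3, dim=dim).shape)  # [9, 4]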

 
