Position embedding implementations in code (Transformer / ViT / Swin / MAE)
Code from the bilibili video "46、四种Position Embedding的原理与PyTorch手写逐行实现(Transformer/ViT/Swin-T/MAE)" (principles and line-by-line handwritten PyTorch implementation of four kinds of position embedding).
Saving this here for now; I will go through it carefully later. Some parts are hard to understand without actually running the code...
import torch
import torch.nn as nn


# 1. 1D absolute sin/cos constant embedding
# standard Transformer
def create_1d_absolute_sincos_embeddings(n_pos_vec, dim):
    # n_pos_vec: torch.arange(n_pos)
    assert dim % 2 == 0, "wrong dimension!"
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)
    omega = torch.arange(dim // 2, dtype=torch.float)
    omega /= dim / 2
    omega = 1. / (10000 ** omega)
    # n_pos_vec as a column vector, omega as a row vector -> [n_pos, dim//2]
    out = n_pos_vec[:, None] @ omega[None, :]
    emb_sin = torch.sin(out)
    emb_cos = torch.cos(out)
    position_embedding[:, 0::2] = emb_sin
    position_embedding[:, 1::2] = emb_cos
    return position_embedding


# 2. 1D absolute trainable embedding
# Vision Transformer
def create_1d_absolute_trainable_embeddings(n_pos_vec, dim):
    # n_pos_vec: torch.arange(n_pos, dtype=torch.float)
    position_embedding = nn.Embedding(n_pos_vec.numel(), dim)
    nn.init.constant_(position_embedding.weight, 0.)
    return position_embedding


# 3. 2D relative bias trainable embedding
# Swin Transformer
def create_2d_relative_bias_trainable_embeddings(n_head, height, width, dim):
    # width:5,  positions [0,1,2,3,4], bias range [-width+1,  width-1],  2*width-1 values
    # height:5, positions [0,1,2,3,4], bias range [-height+1, height-1], 2*height-1 values
    position_embedding = nn.Embedding((2 * width - 1) * (2 * height - 1), n_head)
    nn.init.constant_(position_embedding.weight, 0.)

    def get_relative_position_index(height, width):
        m1, m2 = torch.meshgrid(torch.arange(height), torch.arange(width))
        coords = torch.stack([m1, m2])  # [2, height, width]
        coords_flatten = torch.flatten(coords, 1)  # [2, height*width]
        # pairwise differences of coordinates: [2, height*width, height*width]
        relative_coords_bias = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        # shift to non-negative values
        relative_coords_bias[0, :, :] += height - 1
        relative_coords_bias[1, :, :] += width - 1
        # flatten the 2D index into 1D: A:2d, B:1d, B[i*cols+j] = A[i,j]
        relative_coords_bias[0, :, :] *= relative_coords_bias[1, :, :].max() + 1
        return relative_coords_bias.sum(0)  # [height*width, height*width]

    relative_position_bias = get_relative_position_index(height, width)
    # look up the bias table: [height*width, height*width, n_head]
    bias_embedding = position_embedding(torch.flatten(relative_position_bias)) \
        .reshape(height * width, height * width, n_head)
    # [1, n_head, height*width, height*width]
    bias_embedding = bias_embedding.permute(2, 0, 1).unsqueeze(0)
    return bias_embedding


# 4. 2D absolute constant sin/cos embedding
# Masked Autoencoder
def create_2d_absolute_sincos_embeddings(height, width, dim):
    assert dim % 4 == 0, "wrong dimension!"
    position_embedding = torch.zeros(height * width, dim)
    m1, m2 = torch.meshgrid(torch.arange(height, dtype=torch.float),
                            torch.arange(width, dtype=torch.float))
    coords = torch.stack([m1, m2])  # [2, height, width]
    # build a 1D sin/cos table for each axis, each of size [height*width, dim//2]
    height_embedding = create_1d_absolute_sincos_embeddings(torch.flatten(coords[0]), dim // 2)
    width_embedding = create_1d_absolute_sincos_embeddings(torch.flatten(coords[1]), dim // 2)
    position_embedding[:, :dim // 2] = height_embedding
    position_embedding[:, dim // 2:] = width_embedding
    return position_embedding


if __name__ == "__main__":
    n_pos = 4
    dim = 4
    n_pos_vec = torch.arange(n_pos, dtype=torch.float)
    pe = create_1d_absolute_sincos_embeddings(n_pos_vec, dim)
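The __main__ block above only exercises the first function. As a quick self-check, here is a minimal sketch that calls all four and prints the resulting shapes; it assumes the functions above are in scope, and the sizes (dim=8, a 5x5 grid, 3 heads) are arbitrary example values, not anything prescribed by the video.

import torch

n_pos, dim = 4, 8
n_pos_vec = torch.arange(n_pos, dtype=torch.float)

# 1. standard Transformer: 1D absolute sin/cos table
pe_1d = create_1d_absolute_sincos_embeddings(n_pos_vec, dim)
print(pe_1d.shape)  # torch.Size([4, 8])

# 2. ViT: 1D absolute trainable embedding (returns an nn.Embedding module)
pe_vit = create_1d_absolute_trainable_embeddings(n_pos_vec, dim)
print(pe_vit.weight.shape)  # torch.Size([4, 8])

# 3. Swin: 2D relative bias, shape [1, n_head, height*width, height*width]
bias = create_2d_relative_bias_trainable_embeddings(n_head=3, height=5, width=5, dim=dim)
print(bias.shape)  # torch.Size([1, 3, 25, 25])

# 4. MAE: 2D absolute sin/cos, shape [height*width, dim]
pe_2d = create_2d_absolute_sincos_embeddings(height=5, width=5, dim=dim)
print(pe_2d.shape)  # torch.Size([25, 8])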