Attention Mechanisms - Multi-Head Attention Exercises
1、Visualize attention weights of multiple heads in this experiment.
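The visualization snippet needs a multi-head attention module that has already been run once, so that its `attention_weights` are populated. Below is a minimal sketch of such a setup, assuming the chapter's `MultiHeadAttention` class and its toy sizes (`num_hiddens=100`, `num_heads=5`, a batch of 2 with 4 queries and 6 key-value pairs); the modified class defined for question 2 further down also works, since its constructor signature is the same.

```python
import torch
from d2l import torch as d2l

# Toy setup (sizes follow the chapter's example; any values work):
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()  # disable dropout so the attention weights are deterministic

batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))   # queries
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))   # keys and values
print(attention(X, Y, Y, valid_lens).shape)  # (batch_size, num_queries, num_hiddens)
```

With the module run once, the attention weights of every head can be pulled out of the nested `DotProductAttention` and shown as a grid of heatmaps (one row per batch example, one column per head):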
```python
from matplotlib import pyplot as plt

out = attention.attention.attention_weights.detach()
# out shape is (batch_size * num_heads, no. of queries, no. of key-value pairs)
print(out.shape)
# out1 shape is (batch_size, num_heads, no. of queries, no. of key-value pairs)
out1 = out.reshape(-1, num_heads, out.shape[1], out.shape[2])
print(out1.shape)
d2l.show_heatmaps(out1, xlabel='Keys', ylabel='Queries')
print(out1)
```
2、Suppose that we have a trained model based on multi-head attention and we want to prune the least important attention heads to increase prediction speed. How can we design experiments to measure the importance of an attention head?
1、The output of DotProductAttention has shape [batch_size * num_heads, queries, num_hiddens / num_heads]; reshape it to [batch_size, num_heads, queries, num_hiddens / num_heads].
2、Add a learnable head-weight parameter of shape [1, num_heads, 1, 1].
3、Multiply the DotProductAttention output by the head weights, keeping the shape [batch_size, num_heads, queries, num_hiddens / num_heads].
4、Concatenate the per-head outputs back together.
5、After a training run, read off the learned head weights; heads with smaller weights are less important. The modified MultiHeadAttention below implements these steps.
```python
import torch
from torch import nn
from d2l import torch as d2l


class MultiHeadAttention(nn.Module):
    """Multi-head attention with a learnable weight per head."""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        # Learnable per-head weight, shape (1, num_heads, 1, 1)
        self.head_attention = nn.Parameter(torch.rand(1, num_heads, 1, 1))

    def forward(self, queries, keys, values, valid_lens):
        # Shape of `queries`, `keys`, or `values`:
        # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
        # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
        print("queries,", queries.shape)
        print("self.W_q(queries),", self.W_q(queries).shape)
        print("keys,", keys.shape)
        print("self.W_k(keys),", self.W_k(keys).shape)
        print("values,", values.shape)
        print("self.W_v(values),", self.W_v(values).shape)

        # After transposing, shape of `queries`, `keys`, or `values`:
        # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
        #  `num_hiddens` / `num_heads`).
        # `transpose_qkv` is the helper defined earlier in the chapter.
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        print("*queries", queries.shape)
        print("*keys", keys.shape)
        print("*values", values.shape)

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) `num_heads`
            # times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Shape of `output`:
        # (`batch_size` * `num_heads`, no. of queries, `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)

        # ----------------------- changes start here -----------------------
        # Make the head dimension explicit:
        # (`batch_size`, `num_heads`, no. of queries, `num_hiddens` / `num_heads`)
        output = output.reshape(-1, self.num_heads,
                                output.shape[1], output.shape[2])
        # Normalize the per-head weights so they sum to 1 across heads,
        # then scale each head's output by its weight.
        self.head_attention_tran = torch.softmax(self.head_attention, dim=1)
        output_weighted = output * self.head_attention_tran
        print("output_weighted", output_weighted.shape)
        # Move the head dimension next to the feature dimension and
        # concatenate the heads: (`batch_size`, no. of queries, `num_hiddens`)
        output_weighted = output_weighted.permute(0, 2, 1, 3)
        output_concat = output_weighted.reshape(
            output_weighted.shape[0], output_weighted.shape[1], -1)
        print("output_concat", output_concat.shape)
        # ----------------------- changes end here --------------------------
        return self.W_o(output_concat)
```
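Once such a model has been trained, the learned head weights can be read off directly. Below is a minimal sketch; `attention` is assumed to be a trained instance of the class above, and the names `head_scores` and `ranking` are only illustrative.

```python
import torch

# `attention` is assumed to be a trained instance of the MultiHeadAttention
# class above. Softmax over the head dimension turns the raw parameter of
# shape (1, num_heads, 1, 1) into a normalized importance score per head.
with torch.no_grad():
    head_scores = torch.softmax(attention.head_attention, dim=1).flatten()
print(head_scores)

# Heads with the smallest scores contribute least to the weighted sum and
# are the first candidates for pruning.
ranking = torch.argsort(head_scores, descending=True)
print("heads from most to least important:", ranking.tolist())
```

To confirm that a low-scoring head really is unimportant, one can zero out its weight (or drop its slices of W_q, W_k, W_v, and W_o), then re-measure accuracy and prediction speed and compare against the full model.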