Attention Mechanisms - Multi-Head Attention: Exercises

1. Visualize the attention weights of the multiple heads in this experiment.

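import torch
from d2l import torch as d2l
from matplotlib import pyplot as plt

# `attention` is the MultiHeadAttention instance from this section; its inner
# DotProductAttention stores the weights of the most recent forward pass
out = attention.attention.attention_weights.detach()
# out shape is (batch_size * num_heads, num_queries, num_kv_pairs)
print(out.shape)
# out1 shape is (batch_size, num_heads, num_queries, num_kv_pairs)
out1 = out.reshape(-1, num_heads, out.shape[1], out.shape[2])
print(out1.shape)
# One row of heatmaps per example, one column per head
d2l.show_heatmaps(out1, xlabel='Keys', ylabel='Queries')
plt.show()
print(out1)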
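The reshape above relies on `transpose_qkv` stacking the heads so that row `b * num_heads + h` of `attention_weights` belongs to example `b` and head `h`. The sketch below checks that layout on a toy tensor whose entries encode `(b, h)`; the sizes are arbitrary assumptions, and `flat` only stands in for real attention weights.

```python
import torch

# Toy sizes, chosen only for illustration
batch_size, num_heads, num_queries, num_kv = 2, 5, 4, 6

# Fake attention maps laid out like transpose_qkv output:
# row index = b * num_heads + h, every entry of that map = b * 10 + h
flat = torch.stack([
    torch.full((num_queries, num_kv), float(b * 10 + h))
    for b in range(batch_size) for h in range(num_heads)
])
print(flat.shape)  # torch.Size([10, 4, 6])

# Same reshape as in the visualization code
grid = flat.reshape(-1, num_heads, flat.shape[1], flat.shape[2])
print(grid.shape)  # torch.Size([2, 5, 4, 6])

# grid[b, h] must be the map of example b, head h
for b in range(batch_size):
    for h in range(num_heads):
        assert torch.all(grid[b, h] == b * 10 + h)
```

Because the batch index varies slowest, `reshape(-1, num_heads, ...)` is enough; no `permute` is needed before plotting.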

2. Suppose that we have a trained model based on multi-head attention and want to prune the least important attention heads to increase prediction speed. How can we design experiments to measure the importance of an attention head?

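1. The output of DotProductAttention has shape [batch_size*num_heads, num_queries, num_hiddens/num_heads]; reshape it to [batch_size, num_heads, num_queries, num_hiddens/num_heads].
2. Add a learnable head-weight parameter of shape [1, num_heads, 1, 1], normalized with a softmax over the head axis.
3. Multiply the reshaped attention output by the head weights; broadcasting keeps the shape [batch_size, num_heads, num_queries, num_hiddens/num_heads].
4. Concatenate the heads back into [batch_size, num_queries, num_hiddens] and apply W_o.
5. Train the model with this change; afterwards the learned head weights can be read out, and heads with small weights are candidates for pruning.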
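Steps 1 to 4 can be sketched on random tensors, independent of the d2l classes. The sizes and the names `attn_out`, `head_logits` and `head_weights` are illustrative assumptions; `attn_out` merely stands in for the output of `DotProductAttention`.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, num_heads, num_queries, num_hiddens = 2, 4, 3, 8
head_dim = num_hiddens // num_heads

# Step 1: stand-in for DotProductAttention output, then split out the head axis
attn_out = torch.randn(batch_size * num_heads, num_queries, head_dim)
per_head = attn_out.reshape(-1, num_heads, num_queries, head_dim)

# Step 2: one learnable logit per head, softmax over the head axis
head_logits = nn.Parameter(torch.rand(1, num_heads, 1, 1))
head_weights = torch.softmax(head_logits, dim=1)

# Step 3: broadcasting scales each head by its weight; shape is unchanged
weighted = per_head * head_weights  # (batch, heads, queries, head_dim)

# Step 4: concatenate heads along the feature axis
concat = weighted.permute(0, 2, 1, 3).reshape(batch_size, num_queries,
                                              num_hiddens)
print(concat.shape)  # torch.Size([2, 3, 8])
```

After the concatenation, head `h` occupies features `h * head_dim` to `(h + 1) * head_dim - 1`. Because `head_weights` enters the forward pass, `head_logits` receives gradients and is trained along with the rest of the model, which is what step 5 depends on.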

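import torch
from torch import nn
from d2l import torch as d2l

# `transpose_qkv` is the helper defined earlier in this section of d2l

class MultiHeadAttention(nn.Module):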
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
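        # Added: one learnable logit per head, shape (1, num_heads, 1, 1);
        # a softmax over dim 1 turns these into the head weights
        self.head_attention = nn.Parameter(torch.rand(1, num_heads, 1, 1))
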
    def forward(self, queries, keys, values, valid_lens):
        # Shape of `queries`, `keys`, or `values`:
        # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
        # Shape of `valid_lens`:
        # (`batch_size`,) or (`batch_size`, no. of queries)
        # After transposing, shape of output `queries`, `keys`, or `values`:
        # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
        # `num_hiddens` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for
            # `num_heads` times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_heads,
                                                 dim=0)

        # Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)
        # ***** Different from the d2l implementation: weight each head *****
        # Shape: (`batch_size`, `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = output.reshape(-1, self.num_heads, output.shape[1],
                                output.shape[2])
        # Head weights, shape (1, `num_heads`, 1, 1); stored on the module so
        # they can be read out after training
        self.head_attention_tran = torch.softmax(self.head_attention, dim=1)
        # Broadcasting scales every head by its weight; shape is unchanged
        output_concat3 = output * self.head_attention_tran
        # Concatenate the heads: (`batch_size`, no. of queries, `num_hiddens`)
        output_concat3 = output_concat3.permute(0, 2, 1, 3)
        output_concat = output_concat3.reshape(output_concat3.shape[0],
                                               output_concat3.shape[1], -1)
        # *******************************************************************

        return self.W_o(output_concat)
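Once the head weights are trained, one way to act on them is to rank the heads by weight and zero out the lowest-ranked ones. The helper `prune_mask` and the logit values below are hypothetical, made up for illustration; they are not part of d2l or the original post, and `per_head` is only a stand-in for the reshaped attention output.

```python
import torch

def prune_mask(head_logits: torch.Tensor, num_to_prune: int) -> torch.Tensor:
    """Hypothetical helper: return a (1, num_heads, 1, 1) 0/1 mask that drops
    the `num_to_prune` heads with the smallest softmax weight."""
    weights = torch.softmax(head_logits.detach().flatten(), dim=0)
    order = torch.argsort(weights)  # ascending: least important head first
    mask = torch.ones_like(weights)
    mask[order[:num_to_prune]] = 0.0
    return mask.reshape(1, -1, 1, 1)

# Made-up "learned" logits for 4 heads
learned = torch.tensor([0.9, -1.2, 0.3, 2.0]).reshape(1, 4, 1, 1)
mask = prune_mask(learned, num_to_prune=2)
print(mask.flatten())  # tensor([1., 0., 0., 1.])

# Apply to a stand-in per-head output (batch, heads, queries, head_dim)
per_head = torch.ones(2, 4, 3, 5)
pruned = per_head * mask
print(pruned[0, :, 0, 0])  # tensor([1., 0., 0., 1.])
```

A zero mask only removes a head's contribution; to actually speed up prediction, the matching output slices of `W_q`, `W_k`, `W_v` and the matching input columns of `W_o` would also have to be removed. Comparing validation accuracy with and without each masked head is one way to check that the learned weights agree with each head's real effect.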

posted @ 2021-05-27 17:37 by 哈哈哈喽喽喽