Some pitfalls when implementing a Hierarchical Attention Network in Keras
Reshape
For a tensor x with x.shape = (a, b, c, d):
If you call keras.layers.Reshape(target_shape=(-1, c, d)),
the resulting tensor has shape (?, ?, c, d).
If you call tf.reshape(x, shape=[-1, c, d]) instead,
the resulting tensor has shape (a*b, c, d).
To get the tf.reshape behaviour inside a Keras model, wrap it in a Lambda layer:
Lambda(lambda x: tf.reshape(x, shape=[-1, c, d]))(x)
nice and cool.
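A minimal sketch of the difference, assuming Keras 2 on the TensorFlow 1.x backend; the sizes b=3, c=4, d=5 are made-up values for illustration:

import tensorflow as tf
from keras.layers import Input, Reshape, Lambda

x = Input(shape=(3, 4, 5))    # static shape (None, 3, 4, 5); batch size unknown

# Reshape only rearranges the per-sample axes; the batch axis stays separate,
# so the -1 cannot absorb it and is typically left unresolved.
r1 = Reshape(target_shape=(-1, 4, 5))(x)
print(r1.shape)               # something like (?, ?, 4, 5)

# tf.reshape works on the whole tensor, so the batch axis and the second axis
# collapse into one leading axis (a*b when both are known; None here since a is not).
r2 = Lambda(lambda t: tf.reshape(t, shape=[-1, 4, 5]))(x)
print(r2.shape)               # (?, 4, 5)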
Outputting the attention scores
Here we want the attention layer to output the attention scores themselves, rather than only computing the weighted sum. Usage:
score = Attention()(x)
weighted_sum = MyMerge()([score, x])
import tensorflow as tf
from keras import backend as K
from keras.layers import Layer


class Attention(Layer):
    def __init__(self, **kwargs):
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # Expects input of shape (batch, timesteps, features).
        assert len(input_shape) == 3
        self.w = self.add_weight(name="attention_weight",
                                 shape=(input_shape[-1], input_shape[-1]),
                                 initializer='uniform',
                                 trainable=True)
        self.b = self.add_weight(name="attention_bias",
                                 shape=(input_shape[-1],),
                                 initializer='uniform',
                                 trainable=True)
        self.v = self.add_weight(name="attention_v",
                                 shape=(input_shape[-1], 1),
                                 initializer='uniform',
                                 trainable=True)
        super(Attention, self).build(input_shape)

    def call(self, inputs):
        x = inputs
        # u = tanh(x W + b); scores = softmax(u v) taken over the timestep axis,
        # not over the size-1 last axis (which would make every score equal to 1).
        att = K.tanh(K.dot(x, self.w) + self.b)
        att = K.dot(att, self.v)               # (batch, timesteps, 1)
        att = K.softmax(K.squeeze(att, -1))    # softmax across timesteps
        return K.expand_dims(att, -1)

    def compute_output_shape(self, input_shape):
        return input_shape[0], input_shape[1], 1
class MyMerge(Layer):
    def __init__(self, **kwargs):
        super(MyMerge, self).__init__(**kwargs)

    def call(self, inputs):
        att = inputs[0]    # attention scores, (batch, timesteps, 1)
        x = inputs[1]      # features, (batch, timesteps, features)
        # Repeat the scores across the feature axis, weight the features,
        # then sum over the timestep axis.
        att = tf.tile(att, [1, 1, K.int_shape(x)[-1]])
        outputs = tf.multiply(att, x)
        outputs = K.sum(outputs, axis=1)
        return outputs

    def compute_output_shape(self, input_shape):
        return input_shape[1][0], input_shape[1][2]
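A quick end-to-end check of the two layers together, assuming Keras 2 with the TensorFlow backend; the sizes (10 timesteps, 8 features, batch of 4) are arbitrary:

import numpy as np
from keras.layers import Input
from keras.models import Model

x_in = Input(shape=(10, 8))               # (batch, timesteps, features)
score = Attention()(x_in)                 # (batch, 10, 1) attention scores
weighted = MyMerge()([score, x_in])       # (batch, 8) weighted sum over time

m = Model(x_in, [score, weighted])
s, w = m.predict(np.random.rand(4, 10, 8))
print(s.shape, w.shape)                   # (4, 10, 1) (4, 8)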
Nesting Models in Keras
This part is adapted from https://github.com/uhauha2929/examples/blob/master/Hierarchical%20Attention%20Networks%20.ipynb
Note that sentEncoder is a Model instance; later it is wrapped in TimeDistributed(sentEncoder) and called just like an ordinary layer.
from keras.layers import Input, Embedding, LSTM, Bidirectional, TimeDistributed, Dense
from keras.models import Model

embedding_layer = Embedding(len(word_index) + 1,
                            EMBEDDING_DIM,
                            input_length=MAX_SENT_LENGTH)

sentence_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sentence_input)
l_lstm = Bidirectional(LSTM(100))(embedded_sequences)
sentEncoder = Model(sentence_input, l_lstm)

review_input = Input(shape=(MAX_SENTS, MAX_SENT_LENGTH), dtype='int32')
review_encoder = TimeDistributed(sentEncoder)(review_input)
l_lstm_sent = Bidirectional(LSTM(100))(review_encoder)
preds = Dense(2, activation='softmax')(l_lstm_sent)
model = Model(review_input, preds)
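To tie this back to the earlier section, here is a sketch (my own variant, not part of the original notebook) of dropping the Attention and MyMerge layers into both levels of the hierarchy; return_sequences=True is needed so the attention layers see every timestep:

# Word-level attention inside the sentence encoder (illustrative variant).
sentence_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sentence_input)
l_lstm = Bidirectional(LSTM(100, return_sequences=True))(embedded_sequences)
word_score = Attention()(l_lstm)               # (batch, MAX_SENT_LENGTH, 1)
sent_vec = MyMerge()([word_score, l_lstm])     # (batch, 200)
sentEncoder = Model(sentence_input, sent_vec)

# Sentence-level attention over the encoded sentences.
review_input = Input(shape=(MAX_SENTS, MAX_SENT_LENGTH), dtype='int32')
review_encoder = TimeDistributed(sentEncoder)(review_input)
l_lstm_sent = Bidirectional(LSTM(100, return_sequences=True))(review_encoder)
sent_score = Attention()(l_lstm_sent)          # (batch, MAX_SENTS, 1)
doc_vec = MyMerge()([sent_score, l_lstm_sent]) # (batch, 200)
preds = Dense(2, activation='softmax')(doc_vec)
model = Model(review_input, preds)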