keras中的 concatenate() 详解
最近看模态融合,用到了 keras 中的 concatenate() 函数,之前没有搞明白其中的 axis 这个参数是什么意思。后来经过一番研究,总算是搞明白了。
先看代码
import numpy as np
import keras.backend as K
import tensorflow as tf
a = K.variable(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))
b = K.variable(np.array([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]))
c1 = K.concatenate([a, b], axis=0)
c2 = K.concatenate([a, b], axis=1)
c3 = K.concatenate([a, b], axis=2)
#试试默认的参数,其实就是从倒数第一个维度进行融合的。
c4 = K.concatenate([a, b])
c5 = K.concatenate([a, b],axis=-1)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print('***************')
print(a.shape,b.shape)
print('***************')
print('*****C1******',c1.shape)
print(sess.run(c1))
print()
print('*****C2******',c2.shape)
print(sess.run(c2))
print()
print('*****C3******',c3.shape)
print(sess.run(c3))
print()
print('*****C4******',c4.shape)
print(sess.run(c4))
print('*****C5******',c5.shape)
print(sess.run(c5))
再看看输出的效果:
axis=n表示从第n个维度进行拼接,对于一个三维矩阵,axis的取值可以为[-3, -2, -1, 0, 1, 2]。
axis=-2,意思是从倒数第2个维度进行拼接,对于三维矩阵而言,这就等同于axis=1。
axis=-1,意思是从倒数第1个维度进行拼接,对于三维矩阵而言,这就等同于axis=2。
简单点理解:
可能从图像上来理解比较复杂,但是如果从数学的角度来 看很简单,就比如上边的例子
两个 (2,2,2)(2,2,2)的数组进行融合,,
- 第一个维度融合就是(4,2,2),即 axis=0
- 第二个维度融合就是(2,4,2),即 axis=1
- 第三个维度融合就是(2,2,4),即 axis=2
参考文献
[1]https://blog.csdn.net/leviopku/article/details/82380710
[2]https://zhuanlan.zhihu.com/p/58672698