Tensor的合并与分割
先来看一下有哪些接口用来进行张量的合并与分割:
tf.concat用来进行张量的拼接,tf.stack用来进行张量的堆叠,tf.split用来进行张量的分割,tf.unstack是tf.split的一种,也用来进行张量分割
1.tf.concat
参数axis代表将要合并的维度
# 假设a代表四个班的成绩(每班35人,8个科目),b代表2个班的成绩 a = tf.ones([4,35,8]) b = tf.ones([2,35,8]) # 使用concat进行合并得到6个班的成绩 c = tf.concat([a,b],axis=0) # (6,35,8) print(c.shape)
2.tf.stack(用于创建一个新的维度)
# 假设a代表A学校的四个班的成绩(每班35人,8个科目),b代表B学校四个班的成绩 a = tf.ones([4,35,8]) b = tf.ones([4,35,8]) # 使用stack进行合并得到6个班的成绩 c = tf.stack([a,b],axis=0) # (2,4,35,8) print(c.shape)
3.tf.unstack(对某维度进行等分)
# 假设a代表A学校的四个班的成绩(每班35人,8个科目),b代表B学校四个班的成绩 a = tf.ones([4,35,8]) b = tf.ones([4,35,8]) # 使用stack进行合并得到6个班的成绩 c = tf.stack([a,b],axis=0) # (2,4,35,8) print(c.shape) aa,bb=tf.unstack(c,axis=0) # (4,35,8) print(aa.shape,bb.shape) res=tf.unstack(c,axis=3) # (2,4,35) print(res[0].shape,res[7].shape)
4.tf.split(按比例打散)
# 假设a代表A学校的四个班的成绩(每班35人,8个科目),b代表B学校四个班的成绩 a = tf.ones([4,35,8]) b = tf.ones([4,35,8]) # 使用stack进行合并得到6个班的成绩 c = tf.stack([a,b],axis=0) # (2,4,35,8) print(c.shape) res = tf.split(c,axis=3,num_or_size_splits=2) # 2,(2,4,35,4) print(len(res),res[0].shape,res[1].shape) res = tf.split(c,axis=3,num_or_size_splits=[2,2,4]) # 3 (2,4,35,2) (2,4,35,2) (2,4,35,4) print(len(res),res[0].shape,res[1].shape,res[2].shape)