Tensorflow API 学习(1)-tf.slice()
slice()函数原型为:
tf.slice(input_, begin, size, name=None)
函数有4个参数:
1,input_ :图片的矩阵输入格式。
2,begin :开始截取的位置(输入矩阵的某一点,通常是[x,y,z]的形式)
3,size :从开始截取点向各维度截取的距离(通常也是[x,y,z]的形式)
4,name :该tensor的名字。
tensor(a,b,c)
tensor(z,y,x) 向量在三维坐标的表示如三维坐标轴。tf.slice()参数顺序也是(z,y,x)。
官网例子:
# 'input' is [[[1, 1, 1], [2, 2, 2]],
# [[3, 3, 3], [4, 4, 4]],
# [[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
[4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
[[5, 5, 5]]]
解读:
1,input是一个3维向量,作为tf.slice()函数的输入值(待截取的tensor)
2,第二个参数 [ 1,0,0 ] 是截取的起始点,这里就是第2行的第一个数字 “3”
3,第三个参数有3个示例,只讲(1)(3):
(1) [ 1,1,3 ] 是截取的距离,第一个维度截取1个距离,于是首先截取出[ [ [ 3,3,3 ] , [ 4,4,4 ] ] ] 这部分。第二个维度截取1个距离,那么再截出[ [ [ 3,3,3 ] ] ]这部分。第3个维度截取3个距离,即将全部3个元素截取,得到结果。
(3)[ 2,1,3 ] 第一个维度是2,于是截出来
[ [ [ 3,3,3 ] , [ 4,4,4 ] ,
[ [ 5,5,5 ] , [ 6,6,6 ] ] ]
这部分。第二个维度是1,再截取一个距离,得:
[ [ [ 3,3,3 ] ,
[ [ 5,5,5 ] ] ]
这部分。第三个维度截3个距离,得到结果,若截2个距离,得:
[ [ [ 3,3 ] ,
[ [ 5,5 ] ] ]
注:
(1)第3个参数中可以用-1,如[1,-1,-1],表示第2,3维度从起点一直截取到最后。
(2)多维向量不要理解为线,面,体之类的,那样的话3维以上的点就会对应错。有多少层符号"[ ]",就有多少维,从外层向内层,维度依次增加。