Tensorflow池化

# -*- encoding: utf-8 -*-
import tensorflow as tf

# 定义一张4单通道*4图片
# data = tf.random.truncated_normal(shape=(1, 1, 4, 4))
data = tf.constant(
    [[[[1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]]]],
    dtype="float32"  # avg_pool 要求都是 float32 类型
)

# reshape 成 batch_size, height, width, n_channels ,因为这是 max_pool函数要求的格式
# batch_size=1,因为就一张图片, 高和宽都是4,通道是1
data = tf.reshape(data, [1, 4, 4, 1])

# pool_size 设置成 1,4,1,1; 窗口是[1,1,1,1]
# 1,4,1,1的数组举个例子:
# [
#     [[1]],
#     [[1]],
#     [[1]],
#     [[1]],
# ]
# 第一个数字要与batch_size保持一致,后面的shape定义了一个扫描块,也就是纵向按列扫描
# strides=[1,1,1,1] 每次移动一个单位, 最后输出应是: [1,1,4,1]

output1 = tf.nn.max_pool(data, [1, 4, 1, 1], [1, 1, 1, 1], padding='VALID')
print(output1)

# tf.Tensor(
#     [[[[13]
#        [14]
#        [15]
#        [16]]]], shape=(1, 1, 4, 1), dtype=int32)

output2 = tf.nn.avg_pool(data, [1, 4, 1, 1], [1, 1, 1, 1], padding='VALID')
print(output2)

#
# tf.Tensor(
#     [[[[ 7.]
#        [ 8.]
#        [ 9.]
#        [10.]]]], shape=(1, 1, 4, 1), dtype=float32)

注意事项:

  1. 图片的通道,描述图片用RGB3种颜色,每个颜色都需要一个二维矩阵,成为一个通道
  2. avg_pool 需要输入的数据类型为float, 否则报错:tensorflow.python.framework.errors_impl.NotFoundError: Could not find valid device for node.
  3. 输入数据的格式需要为一个4维数组,shape=(batch_size, height, width, n_channels ) ,这个格式专门为图片设定的,其他类型要自己转换

池化说明:

TF提供了tf.keras.layers.AvgPool2D,tf.keras.layers.MaxPool2D 来搭建池化层。

posted @ 2020-11-24 10:43  oaksharks  阅读(154)  评论(0编辑  收藏  举报