# TF2 implementation of the PSPNet semantic segmentation network


"""
Created on 2020/11/29 19:51.

@Author: yubaby@anne
@Email: yhaif@foxmail.com
"""


import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, BatchNormalization, Activation
from tensorflow.keras.layers import Add, ZeroPadding2D, AveragePooling2D, Lambda, Concatenate
from tensorflow.keras import Model
import tensorflow.keras.backend as K


# This implementation is written against the TF1-style graph API
# (tf.compat.v1.image.resize_images below), so eager execution is
# disabled globally at import time.
tf.compat.v1.disable_eager_execution()
IMAGE_ORDERING = 'channels_last'
# Channel axis used when concatenating feature maps; depends on data layout.
if IMAGE_ORDERING == 'channels_first':  # 'NCHW'
    MERGE_AXIS = 1
elif IMAGE_ORDERING == 'channels_last':  # 'NHWC'
    MERGE_AXIS = -1

def identity_block(input_x, filter_list, dilation_rate=1):
    """Bottleneck residual block with an identity shortcut ("solid-line" block).

    Layout: 1x1 reduce -> 3x3 (optionally dilated) -> 1x1 expand, each
    followed by BatchNorm, with ReLU after the first two. The input is
    added back unchanged, so `input_x` must already have `filter_list[2]`
    channels for the element-wise Add to be valid.

    Args:
        input_x: 4-D feature tensor (NHWC here).
        filter_list: three filter counts [reduce, mid, expand].
        dilation_rate: dilation for the middle 3x3 convolution.

    Returns:
        Output tensor with the same spatial size and channel count as input_x.
    """
    f_reduce, f_mid, f_expand = filter_list

    out = Conv2D(filters=f_reduce, kernel_size=(1, 1), strides=1, padding='same')(input_x)
    out = Activation('relu')(BatchNormalization()(out))

    out = Conv2D(filters=f_mid, kernel_size=(3, 3), strides=1, padding='same',
                 dilation_rate=dilation_rate)(out)
    out = Activation('relu')(BatchNormalization()(out))

    out = Conv2D(filters=f_expand, kernel_size=(1, 1), strides=1, padding='same')(out)
    out = BatchNormalization()(out)

    # Identity shortcut, then the final activation.
    return Activation('relu')(Add()([out, input_x]))


def conv_block(input_x, filter_list, strides=2, dilation_rate=1):
    """Bottleneck residual block with a projection shortcut ("dashed-line" block).

    Same 1x1 -> 3x3 -> 1x1 layout as identity_block, but the shortcut is a
    strided 1x1 convolution + BatchNorm so the block can change both the
    spatial resolution (via `strides`) and the channel count.

    Args:
        input_x: 4-D feature tensor (NHWC here).
        filter_list: three filter counts [reduce, mid, expand].
        strides: stride applied to the first conv and to the shortcut.
        dilation_rate: dilation for the middle 3x3 convolution.

    Returns:
        Output tensor with `filter_list[2]` channels.
    """
    f_reduce, f_mid, f_expand = filter_list

    out = Conv2D(filters=f_reduce, kernel_size=(1, 1), strides=strides, padding='same')(input_x)
    out = Activation('relu')(BatchNormalization()(out))

    out = Conv2D(filters=f_mid, kernel_size=(3, 3), strides=1, padding='same',
                 dilation_rate=dilation_rate)(out)
    out = Activation('relu')(BatchNormalization()(out))

    out = Conv2D(filters=f_expand, kernel_size=(1, 1), strides=1, padding='same')(out)
    out = BatchNormalization()(out)

    # Projection shortcut: match the main path's stride and channel count.
    shortcut = Conv2D(filters=f_expand, kernel_size=(1, 1), strides=strides, padding='same')(input_x)
    shortcut = BatchNormalization()(shortcut)

    return Activation('relu')(Add()([out, shortcut]))


def get_resnet50_encoder(inputs, downsample_factor=16):
    """ResNet-50-style backbone returning five intermediate feature maps.

    The stem uses three 3x3 convolutions (instead of the classic single 7x7),
    followed by the standard conv2_x..conv5_x bottleneck stages with
    [3, 4, 6, 3] blocks.

    Args:
        inputs: input image tensor (NHWC).
        downsample_factor: total spatial reduction of the deepest features.
            16 keeps the stride-2 downsampling in conv4 and dilates conv5;
            8 removes the conv4 stride and uses dilation in conv4 and conv5
            instead, preserving a 1/8-resolution feature map.

    Returns:
        List [f1, f2, f3, f4, f5] of feature tensors from shallow to deep.

    Raises:
        ValueError: if downsample_factor is not 8 or 16. (The original code
            silently fell through and later crashed with UnboundLocalError.)
    """
    if downsample_factor == 16:
        block4_stride, block4_dilation, block5_dilation = 2, 1, 2
    elif downsample_factor == 8:
        block4_stride, block4_dilation, block5_dilation = 1, 2, 4
    else:
        raise ValueError('downsample_factor must be 8 or 16, got %r' % (downsample_factor,))

    # Number of residual blocks in conv2_x .. conv5_x (ResNet-50 layout).
    block_list = [3, 4, 6, 3]

    # conv1 stem: three stride-controlled 3x3 convs + max pool.
    x = ZeroPadding2D(padding=(1, 1))(inputs)
    x = Conv2D(filters=64, kernel_size=(3, 3), strides=2, padding='valid')(x)
    # NOTE(review): f1 is captured before BatchNorm/ReLU — presumably
    # intentional, kept as in the original; confirm against decoder usage.
    f1 = x
    x = BatchNormalization(axis=-1)(x)
    x = Activation('relu')(x)
    # -----------------------------------------------------
    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='valid')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(filters=128, kernel_size=(3, 3), strides=1, padding='valid')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D(padding=(1, 1))(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)
    # -----------------------------------------------------

    # conv2_x: projection block (stride 1) then identity blocks.
    x = conv_block(input_x=x, filter_list=[64, 64, 256], strides=1)
    for _ in range(block_list[0] - 1):
        x = identity_block(input_x=x, filter_list=[64, 64, 256])
    f2 = x

    # conv3_x
    x = conv_block(input_x=x, filter_list=[128, 128, 512])
    for _ in range(block_list[1] - 1):
        x = identity_block(input_x=x, filter_list=[128, 128, 512])
    f3 = x

    # conv4_x: stride/dilation depend on downsample_factor (see above).
    x = conv_block(input_x=x, filter_list=[256, 256, 1024], strides=block4_stride)
    for _ in range(block_list[2] - 1):
        x = identity_block(input_x=x, filter_list=[256, 256, 1024], dilation_rate=block4_dilation)
    f4 = x

    # conv5_x: always stride 1; resolution kept via dilation.
    x = conv_block(input_x=x, filter_list=[512, 512, 2048], strides=1, dilation_rate=block5_dilation)
    for _ in range(block_list[3] - 1):
        x = identity_block(input_x=x, filter_list=[512, 512, 2048], dilation_rate=block5_dilation)
    f5 = x

    return [f1, f2, f3, f4, f5]


def pool_block(feats, pool_factor):
    """One branch of the PSP pyramid-pooling module.

    Average-pools `feats` into roughly pool_factor x pool_factor regions,
    projects to 512 channels with a 1x1 conv (+BN+ReLU), then bilinearly
    resizes back to the input's spatial size so all branches can be
    concatenated.

    Args:
        feats: 4-D feature tensor with a statically known spatial shape.
        pool_factor: number of pooling regions along each spatial axis.

    Returns:
        Tensor with 512 channels and the same spatial size as `feats`.
    """
    in_h = K.int_shape(feats)[1]
    in_w = K.int_shape(feats)[2]
    # Region size chosen so pool_factor regions tile the feature map.
    region = [int(np.round(float(in_h) / pool_factor)),
              int(np.round(float(in_w) / pool_factor))]

    pooled = AveragePooling2D(pool_size=region, strides=region, padding='same',
                              data_format=IMAGE_ORDERING)(feats)
    pooled = Conv2D(filters=512, kernel_size=(1, 1), padding='same',
                    data_format=IMAGE_ORDERING, use_bias=False)(pooled)
    pooled = BatchNormalization()(pooled)
    pooled = Activation('relu')(pooled)

    # Upsample back to the original spatial resolution (TF1-style resize).
    resized = Lambda(
        lambda t: tf.compat.v1.image.resize_images(t, (in_h, in_w), align_corners=True)
    )(pooled)
    return resized


def build_model(tif_size, bands, class_num, aux_branch=True, downsample_factor=16):
    """Build the PSPNet segmentation model.

    Runs a ResNet-50 encoder, applies the pyramid-pooling module to its
    deepest features, and upsamples a per-pixel softmax back to the input
    resolution. Optionally adds an auxiliary softmax head on the conv4
    features (training aid only; use the "main" output for prediction).

    Args:
        tif_size: input image height/width (square inputs).
        bands: number of input channels.
        class_num: number of segmentation classes.
        aux_branch: whether to attach the auxiliary loss head
            (was a hard-coded local; default preserves old behavior).
        downsample_factor: encoder output stride, 8 or 16
            (was a hard-coded local; default preserves old behavior).

    Returns:
        keras Model; outputs [aux, main] when aux_branch is True,
        otherwise [main].
    """
    from pathlib import Path
    import sys
    print('===== %s =====' % Path(__file__).name)
    print('===== %s =====' % sys._getframe().f_code.co_name)

    inputs = Input(shape=(tif_size, tif_size, bands))
    inputs_size = (tif_size, tif_size, bands)

    levels = get_resnet50_encoder(inputs, downsample_factor)
    [f1, f2, f3, f4, f5] = levels

    # -------------------------------------#
    #   PSP module: pool the deepest features at several scales,
    #   then concatenate the branches with the original features.
    # -------------------------------------#
    pool_factors = [1, 2, 3, 6]
    o = f5
    pool_outs = [o]
    for p in pool_factors:
        pool_outs.append(pool_block(o, p))
    o = Concatenate(axis=MERGE_AXIS)(pool_outs)
    # -------------------------------------#

    # Fuse the concatenated pyramid features.
    o = Conv2D(512, (3, 3), data_format=IMAGE_ORDERING, padding='same', use_bias=False)(o)
    o = BatchNormalization()(o)
    o = Activation('relu')(o)
    o = Dropout(0.1)(o)

    # Per-pixel class logits, upsampled to input resolution.
    o = Conv2D(class_num, (1, 1), data_format=IMAGE_ORDERING, padding='same')(o)
    o = Lambda(lambda x: tf.compat.v1.image.resize_images(x, (inputs_size[1], inputs_size[0]), align_corners=True))(o)
    o = Activation("softmax", name="main")(o)

    # Auxiliary loss head on conv4 features (training aid; "main" beats "aux").
    if aux_branch:
        f4 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', use_bias=False)(f4)
        f4 = BatchNormalization()(f4)
        f4 = Activation('relu')(f4)
        f4 = Dropout(0.1)(f4)
        f4 = Conv2D(class_num, (1, 1), data_format=IMAGE_ORDERING, padding='same')(f4)
        f4 = Lambda(
            lambda x: tf.compat.v1.image.resize_images(x, (inputs_size[1], inputs_size[0]), align_corners=True))(f4)
        f4 = Activation("softmax", name="aux")(f4)
        return Model(inputs, [f4, o])
    else:
        return Model(inputs, [o])

# NOTE(review): removed blog-page footer text (post metadata, related-post
# links, share prompts) that was accidentally pasted into this source file;
# it was not Python and made the module unparseable.