"""
Created on 2020/11/29 19:59.
@Author: yubaby@anne
@Email: yhaif@foxmail.com
"""
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation
from tensorflow.keras.layers import Add, Concatenate, UpSampling2D, SeparableConv2D
from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape
from tensorflow.keras import Model
def _sep_conv_bn(x, filters, strides=(1, 1)):
    """Apply a 3x3 'same'-padded SeparableConv2D followed by BatchNormalization."""
    x = SeparableConv2D(filters, kernel_size=(3, 3), strides=strides, padding='same')(x)
    return BatchNormalization()(x)


def _xception_entry_block(x, filters, pre_activation=True):
    """Strided Xception entry-flow residual block with a 1x1 projection shortcut.

    Args:
        x: input tensor.
        filters: number of filters for the shortcut and every separable conv.
        pre_activation: apply a ReLU to ``x`` before the first separable conv.
            The first entry block follows an already-activated stem, so it
            passes False.

    Returns:
        ``(output, tap)``. ``output`` is the residual sum, downsampled 2x by the
        stride-2 shortcut and final separable conv. ``tap`` is the ReLU
        activation that feeds that final stride-2 separable conv, at this
        block's input resolution.
    """
    # The shortcut is created first to keep the original layer creation order
    # (and therefore the auto-generated layer names).
    shortcut = Conv2D(filters, kernel_size=(1, 1), strides=(2, 2), padding='same')(x)
    shortcut = BatchNormalization()(shortcut)
    y = Activation('relu')(x) if pre_activation else x
    y = _sep_conv_bn(y, filters)
    y = Activation('relu')(y)
    y = _sep_conv_bn(y, filters)
    tap = Activation('relu')(y)
    y = _sep_conv_bn(tap, filters, strides=(2, 2))
    return Add()([y, shortcut]), tap


def _xception_relu_sep_chain(x, filter_sizes):
    """Apply ReLU -> SeparableConv2D(3x3) -> BatchNormalization once per entry of filter_sizes."""
    for filters in filter_sizes:
        x = Activation('relu')(x)
        x = _sep_conv_bn(x, filters)
    return x


def Xception(inputs, middle_flow_repeats=16):
    """Build the Xception-style encoder.

    Layout:
        - A two-conv stem (3x3 convs with 32 and 64 filters; the first has stride 2).
        - Three strided entry-flow residual blocks (128, 256 and 728 filters).
        - ``middle_flow_repeats`` identity-shortcut residual blocks at 728 filters.
        - A stride-1 exit flow that ends in a 2048-channel
          SeparableConv2D -> BatchNormalization -> ReLU.

    Args:
        inputs: Keras input tensor. Assumed channels-last (the layers use the
            Keras default data format); TODO confirm if the config changes it.
        middle_flow_repeats: number of middle-flow residual blocks (default 16).

    Returns:
        ``(x, x_low_level_feature)``:
        - ``x`` is the 2048-channel encoder output, downsampled 16x. There are
          four stride-2 stages with 'same' padding, and each stage rounds up.
        - ``x_low_level_feature`` is the 256-channel ReLU activation inside the
          second entry block, taken before its stride-2 SeparableConv2D.
          It is downsampled 4x.
    """
    # Entry-flow stem: Conv2D -> BN -> ReLU twice.
    x = inputs
    for filters, strides in ((32, (2, 2)), (64, (1, 1))):
        x = Conv2D(filters, kernel_size=(3, 3), strides=strides, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Entry flow: each block halves the spatial size. The second block's
    # pre-stride activation is kept as the low-level skip feature.
    x, _ = _xception_entry_block(x, 128, pre_activation=False)
    x, x_low_level_feature = _xception_entry_block(x, 256)
    x, _ = _xception_entry_block(x, 728)

    # Middle flow: identity-shortcut residual blocks, spatial size unchanged.
    for _ in range(middle_flow_repeats):
        residual = _xception_relu_sep_chain(x, (728, 728, 728))
        x = Add()([residual, x])

    # Exit-flow residual block. The 1x1 projection shortcut is needed because
    # the channel count changes to 1024; the stride stays at 1.
    shortcut = Conv2D(1024, kernel_size=(1, 1), strides=(1, 1), padding='same')(x)
    shortcut = BatchNormalization()(shortcut)
    residual = _xception_relu_sep_chain(x, (728, 1024, 1024))
    x = Add()([residual, shortcut])

    # Exit-flow tail: SeparableConv2D -> BN -> ReLU.
    for filters in (1536, 1536, 2048):
        x = _sep_conv_bn(x, filters)
        x = Activation('relu')(x)
    return x, x_low_level_feature
def ASPP(x, filter_num, old_filter_num=None):
    """Atrous Spatial Pyramid Pooling head.

    Five parallel branches are applied to ``x``:
        - image-level pooling;
        - a 1x1 conv;
        - three 3x3 atrous convs with dilation 6, 12 and 18.
    Each branch is Conv2D -> BatchNormalization -> ReLU with ``filter_num``
    filters. The branches are concatenated and fused by a 1x1
    Conv2D -> BatchNormalization -> ReLU.

    Args:
        x: 4-D feature map. Assumed channels-last (batch, height, width,
            channels), which the Reshape of the pooled branch relies on.
        filter_num: number of filters in every branch and in the fused output.
        old_filter_num: number of channels in ``x``. Optional; when omitted it
            is read from ``x``'s static shape. Kept for backward compatibility.

    Returns:
        Tensor with ``x``'s spatial size and ``filter_num`` channels.
    """
    if old_filter_num is None:
        old_filter_num = int(x.shape[-1])

    def _conv_bn_relu(tensor, kernel_size, dilation_rate=1):
        # Stride-1, 'same'-padded Conv2D with filter_num filters -> BN -> ReLU.
        tensor = Conv2D(filter_num, kernel_size=kernel_size, strides=(1, 1),
                        padding='same', dilation_rate=dilation_rate)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    # Image-level branch: global average pool, 1x1 projection, then upsample
    # the 1x1 map back to x's spatial size. The factor comes from x's static
    # shape rather than a hard-coded 16, so the Concatenate below works for
    # any input size.
    height, width = x.shape[1], x.shape[2]
    if height is None or width is None:
        # Spatial size unknown at build time: fall back to a fixed 16x16
        # upsampling, which only lines up with x when x is 16x16 at run time.
        upsample_size = (16, 16)
    else:
        upsample_size = (int(height), int(width))
    x_pool = GlobalAveragePooling2D()(x)
    x_pool = Reshape((1, 1, old_filter_num))(x_pool)
    x_pool = _conv_bn_relu(x_pool, (1, 1))
    x_pool = UpSampling2D(size=upsample_size)(x_pool)

    # 1x1 branch plus three 3x3 atrous branches; 'same' padding keeps x's size.
    branches = [x_pool, _conv_bn_relu(x, (1, 1))]
    branches.extend(_conv_bn_relu(x, (3, 3), rate) for rate in (6, 12, 18))

    # Fuse all five branches with a 1x1 projection.
    fused = Concatenate()(branches)
    return _conv_bn_relu(fused, (1, 1))
def build_model(tif_size, bands, class_num):
    """Build the encoder-decoder segmentation model (DeepLabV3+-style layout).

    The Xception encoder feeds an ASPP head. The decoder fuses the ASPP output,
    upsampled 4x, with a 48-channel projection of the encoder's low-level
    features. It then applies two 3x3 Conv2D -> BN -> ReLU layers, upsamples
    4x back to the input size and ends with a per-pixel softmax.

    Args:
        tif_size: height and width of the square input. Must be a positive
            multiple of 16.
        bands: number of input channels.
        class_num: number of output classes (softmax channels).

    Returns:
        An uncompiled ``tensorflow.keras.Model`` mapping
        ``(tif_size, tif_size, bands)`` inputs to
        ``(tif_size, tif_size, class_num)`` class probabilities.

    Raises:
        ValueError: if ``tif_size`` is not a positive multiple of 16.
    """
    from pathlib import Path

    # Debug banner: source file name and this function's name.
    print('===== %s =====' % Path(__file__).name)
    print('===== %s =====' % build_model.__name__)

    # The encoder has four stride-2 stages ('same' padding rounds up), and the
    # decoder upsamples 4x twice. Any size that is not a multiple of 16 either
    # breaks the low-level/high-level Concatenate or produces an output whose
    # size differs from the input, so reject it up front.
    if tif_size <= 0 or tif_size % 16 != 0:
        raise ValueError(
            'tif_size must be a positive multiple of 16, got %r' % (tif_size,))

    inputs = Input(shape=(tif_size, tif_size, bands))
    x, x_low_level_feature = Xception(inputs)
    x_high_level_feature = ASPP(x, 256, 2048)

    # Decoder: reduce low-level channels to 48, bring the ASPP output (1/16
    # resolution) up to the low-level resolution (1/4), and concatenate.
    x_low_level_feature = Conv2D(48, kernel_size=(1, 1), strides=(1, 1), padding='same')(x_low_level_feature)
    x_high_level_feature = UpSampling2D(size=(4, 4))(x_high_level_feature)
    x = Concatenate()([x_low_level_feature, x_high_level_feature])
    for _ in range(2):
        x = Conv2D(256, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Back to input resolution, then per-pixel class probabilities.
    x = UpSampling2D(size=(4, 4))(x)
    x = Conv2D(class_num, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='softmax')(x)
    mymodel = Model(inputs, x)
    return mymodel
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· Docker 太简单,K8s 太复杂?w7panel 让容器管理更轻松!