Implementing Segformer with TensorFlow 2.3
import tensorflow as tf
import math
from math import sqrt
from tensorflow import keras
import sys
from functools import partial
import cv2
import numpy as np
from einops import rearrange


def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth

# print(cast_tuple((2,1),5))
# print(cast_tuple(2,5))
#
# dims = (32, 64, 160, 256),
# heads = (1, 2, 5, 8),
# ff_expansion = (8, 8, 4, 4),
# reduction_ratio = (8, 4, 2, 1),
# num_layers = 2,
# channels = 3,
# decoder_dim = 256,
# num_classes = 4
# c = dims, heads, ff_expansion, reduction_ratio, num_layers = map(
#     partial(cast_tuple, depth = 4), (dims, heads, ff_expansion, reduction_ratio, num_layers)
# )
#


# Depthwise separable convolution: a grouped (depthwise) conv followed by a 1x1 pointwise conv.
class DsConv2d(keras.layers.Layer):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride=(1, 1), bias=True):
        super(DsConv2d, self).__init__()
        self.net = keras.Sequential([
            keras.layers.BatchNormalization(),
            keras.layers.Conv2D(dim_in, kernel_size, padding=padding, groups=dim_in,
                                strides=stride, use_bias=bias),  # possible improvement: could be made deeper
            keras.layers.Conv2D(dim_out, kernel_size=1, use_bias=bias)
        ])

    def call(self, x):
        # print('x:', x.shape)
        return self.net(x)

# d_conv = DsConv2d(dim_in=3, dim_out=512, kernel_size=3, padding='SAME')
# input_data = keras.layers.Input(shape=(101, 101, 3))
# deal1 = d_conv(input_data)
# print(deal1.shape)


# Applies LayerNormalization before the wrapped sub-layer.
class PreNorm(tf.keras.Model):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.fn = fn

    def call(self, x):
        return self.fn(self.norm(x))


# Efficient self-attention: keys and values are spatially reduced by `reduction_ratio`
# (strided convolution) before attention, which shrinks the attention matrix.
class EfficientSelfAttention(keras.layers.Layer):
    def __init__(self, *, dim, heads, reduction_ratio):
        super(EfficientSelfAttention, self).__init__()
        self.scale = (dim // heads) ** -0.5
        self.heads = heads
        self.to_q = keras.layers.Conv2D(dim, 1, use_bias=False)
        self.to_k = keras.layers.Conv2D(dim * 1, reduction_ratio, strides=reduction_ratio, use_bias=False)
        self.to_v = keras.layers.Conv2D(dim * 1, reduction_ratio, strides=reduction_ratio, use_bias=False)
        self.to_out = keras.layers.Conv2D(dim, 1, use_bias=False)

    def call(self, x):
        h, w = x.shape[-3:-1]
        heads = self.heads
        q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))  # note the order: q, k, v
        q, k, v = map(lambda t: rearrange(t, 'b x y (h c) -> (b h) (x y) c', h=heads), (q, k, v))
        sim = tf.einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = tf.nn.softmax(sim, axis=-1)
        out = tf.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) c -> b x y (h c)', h=heads, x=h, y=w)
        return self.to_out(out)


def gelu(x):
    """Gaussian Error Linear Unit.

    This is a smoother version of the RELU.
    Original paper: https://arxiv.org/abs/1606.08415

    Args:
        x: float Tensor to perform activation.

    Returns:
        `x` with the GELU activation applied.
    """
    cdf = 0.5 * (1.0 + tf.tanh(
        (math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf


class GELU(keras.layers.Layer):
    def call(self, inputs, **kwargs):
        cdf = 0.5 * (1.0 + tf.tanh(
            (math.sqrt(2 / math.pi) * (inputs + 0.044715 * tf.pow(inputs, 3)))))
        return inputs * cdf


# Mix-FFN: 1x1 conv -> depthwise separable 3x3 conv -> GELU -> 1x1 conv.
class MixFeedForward(keras.layers.Layer):
    def __init__(self, *, dim, expansion_factor):
        super().__init__()
        hidden_dim = dim * expansion_factor
        self.net = keras.Sequential([
            keras.layers.Conv2D(hidden_dim, 1),
            DsConv2d(hidden_dim, hidden_dim, 3, padding='SAME'),  # possible improvement: add dilated (atrous) convolutions
            keras.layers.BatchNormalization(),
            GELU(),
            keras.layers.Conv2D(dim, 1)
        ])

    def call(self, x):
        return self.net(x)

# class my_Unfold(keras.layers.Layer):
#     def __init__(self, *, kernel, stride, padding, rates=[1, 1, 1, 1]):
#         self.kernel = kernel
#         self.stride = stride
#         self.padding = padding
#         self.rates = rates
#     def call(self, inputs, **kwargs):
#         return tf.image.extract_patches(images=inputs, sizes=[1, self.kernel, self.kernel, 1],
#                                         strides=[1, self.stride, self.stride, 1],
#                                         rates=[1, 1, 1, 1], padding='SAME')


# Mix Transformer (MiT) encoder: four stages of overlapped patch embedding followed by
# (efficient self-attention + Mix-FFN) blocks, each stage halving the spatial resolution.
class MiT(keras.layers.Layer):
    # dims = (32, 64, 160, 256), heads = (1, 2, 5, 8), ff_expansion = (8, 8, 4, 4),
    # reduction_ratio = (16, 4, 2, 1), num_layers = 2, channels = 3, decoder_dim = 256, num_classes = 4
    def __init__(self, *, channels, dims, heads, ff_expansion, reduction_ratio, num_layers, stage_kernel_stride_pad):
        super().__init__()
        # self.stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
        # self.stage_kernel_stride_pad = ((3, 2), (3, 2), (3, 2), (3, 2))  # patch
        self.stage_kernel_stride_pad = stage_kernel_stride_pad
        dims = (channels, *dims)
        dim_pairs = list(zip(dims[:-1], dims[1:]))
        self.stages = []
        # self.stages = keras.Sequential()
        for (dim_in, dim_out), num_layers, ff_expansion, heads, reduction_ratio in \
                zip(dim_pairs, num_layers, ff_expansion, heads, reduction_ratio):
            # get_overlap_patches = nn.Unfold(kernel, stride = stride, padding = padding)
            # get_overlap_patches = my_Unfold(kernel=kernel, stride=stride, padding='SAME')
            overlap_patch_embed = keras.layers.Conv2D(dim_out, 1)  # possible change: use a 3x3 kernel to capture a larger local neighborhood
            layers = []
            # layers = keras.Sequential()
            # for _ in range(num_layers):
            # NOTE: hard-coded to 4 transformer blocks per stage; use range(num_layers) to honor the num_layers argument.
            for _ in range(4):
                layers.append([
                    PreNorm(dim_out, EfficientSelfAttention(dim=dim_out, heads=heads, reduction_ratio=reduction_ratio)),
                    PreNorm(dim_out, MixFeedForward(dim=dim_out, expansion_factor=ff_expansion)),
                ])
            self.stages.append([
                # get_overlap_patches,
                overlap_patch_embed,
                layers
            ])

    def call(self, x, return_layer_outputs=False):
        layer_outputs = []
        layer_index = 0
        # for (get_overlap_patches, overlap_embed, layers) in self.stages:
        for (overlap_embed, layers) in self.stages:
            ksize, stride = self.stage_kernel_stride_pad[layer_index]
            # Overlapped patch extraction: unfold local patches into the channel dimension,
            # then project them with the 1x1 patch-embedding convolution.
            x = tf.image.extract_patches(images=x, sizes=[1, ksize, ksize, 1],
                                         strides=[1, stride, stride, 1],
                                         rates=[1, 1, 1, 1], padding='SAME')
            # num_patches = x.shape[-1]
            # ratio = int(sqrt((h * w) / num_patches))
            # x = rearrange(x, 'b c (h w) -> b c h w', h = h // ratio)
            x = overlap_embed(x)
            for (attn, ff) in layers:
                x = attn(x) + x
                x = ff(x) + x
            layer_outputs.append(x)
            layer_index += 1
        ret = x if not return_layer_outputs else layer_outputs
        return ret


# SegFormer: MiT encoder plus a lightweight decoder that projects every stage output to
# `decoder_dim`, upsamples them to a common resolution, concatenates them, and predicts the masks.
# @tf.function
class Segformer(keras.Model):
    def __init__(self, dims=(32, 64, 160, 256), heads=(1, 2, 5, 8), ff_expansion=(8, 8, 4, 4),
                 reduction_ratio=(16, 4, 2, 1), num_layers=2, channels=3, decoder_dim=256, num_classes=4,
                 stage_kernel_stride_pad=((3, 2), (3, 2), (3, 2), (3, 2))):
        super().__init__()
        dims, heads, ff_expansion, reduction_ratio, num_layers = map(
            partial(cast_tuple, depth=4), (dims, heads, ff_expansion, reduction_ratio, num_layers)
        )
        self.upscale = stage_kernel_stride_pad[0][1]
        assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), \
            'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'

        self.mit = MiT(
            channels=channels,
            dims=dims,
            heads=heads,
            ff_expansion=ff_expansion,
            reduction_ratio=reduction_ratio,
            num_layers=num_layers,
            stage_kernel_stride_pad=stage_kernel_stride_pad
        )

        self.to_fused = [
            keras.Sequential([
                keras.layers.Conv2D(decoder_dim, 1),
                # nn.Upsample(scale_factor=2 ** i)
                keras.layers.UpSampling2D(size=(2 ** i, 2 ** i))
            ]) for i, dim in enumerate(dims)
        ]

        self.to_segmentation = keras.Sequential([
            keras.layers.Conv2D(decoder_dim, 1),
            # improvement: normalization + activation
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            # keras.layers.Conv2DTranspose(filters=decoder_dim // 2, kernel_size=[self.upscale + 2, self.upscale + 2],
            #                              strides=(self.upscale, self.upscale), padding='SAME'),
            keras.layers.UpSampling2D(size=(self.upscale, self.upscale)),
            # keras.layers.Conv2D(decoder_dim, 3, 1, activation=tf.nn.relu),
            keras.layers.Conv2D(num_classes, 1, activation=tf.nn.sigmoid),
        ])

    @tf.function
    def call(self, x):
        layer_outputs = self.mit(x, return_layer_outputs=True)
        # fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
        fused = []
        for i_stage in range(len(self.to_fused)):
            output = layer_outputs[i_stage]
            to_fused = self.to_fused[i_stage]
            fused.append(to_fused(output))
        fused = tf.concat(fused, axis=-1)
        logits = self.to_segmentation(fused)
        # softmax_out = tf.nn.softmax(logits, axis=-1)
        return logits


############## Test ################################
# if __name__ == '__main__':
#     s_model = Segformer()
#     img_path = '1.PNG'
#     img = cv2.imread(img_path, 1)
#
#     input_data = np.reshape(img, (1, 256, 256, 3)).astype(np.float32)
#
#     c = s_model(input_data)
#
#     print('input_data:', input_data.shape)
#     print('c:', c.shape)
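
As a quick sanity check that does not depend on an image file being on disk, a minimal smoke test might look like the sketch below. It assumes the default constructor arguments (3 input channels, 4 classes, a stride-2 patch embedding in every stage) and feeds a random 256x256 tensor through the model; with those defaults the decoder upsamples back to the input resolution, so the expected output shape is (1, 256, 256, 4).

if __name__ == '__main__':
    # Minimal smoke test (sketch): random input instead of cv2.imread.
    model = Segformer(num_classes=4)
    dummy_input = np.random.rand(1, 256, 256, 3).astype(np.float32)
    pred = model(dummy_input)
    print('input :', dummy_input.shape)   # (1, 256, 256, 3)
    print('output:', pred.shape)          # expected (1, 256, 256, 4)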