【600】Attention U-Net 解释
参考:Attention-UNet for Pneumothorax Segmentation
一、Model 结构图
说明:图中是 3D 数据,F 代表 feature(即通道 channel),H 代表 height,W 代表 width,D 代表 depth,即 3D 数据块的深度。对于普通的 2D 图片数据,可以去掉 D;另外这里习惯把通道放在最后(channels_last),因此可以表示为 $H_1 \times W_1 \times F_1$。
二、AttnBlock2D 函数的图示
下图为 AttnBlock2D 函数的实现效果。其输出相当于 U-Net 中 skip connection 传过来的 layer(已经过注意力加权),后面还需要再接一个 Concatenation,与上采样后的特征拼接。
以上为 Attention Gate 的原始结构图,可以按照下面的结构图进行理解:
- 输入为 $x$(最上方的 conv2d_126,分成两路)和 $g$(左边的 up_sampling2d_11)
- $x$ 经过一个 1×1 卷积、$g$ 经过一个 1×1 卷积,然后两者逐元素相加
- $x$ 经过卷积后的通道数为 g.channels // 4(注意不是 x.channels // 4:代码中 inter_channel = 下层 down_layer 的通道数 // 4,此例为 512 // 4 = 128)
- $g$ 经过卷积后的通道数同样为 g.channels // 4,两者形状一致才能相加
- 之后依次经过 ReLU、1×1 卷积、Sigmoid,得到注意力权重图,即下图中的 activation_19
- 该卷积的输出通道数为 1,得到的单通道权重图随后与 $x$ 的每个通道相乘,从而实现 Attention
- 最后将 activation_19 与 $x$(最上方的 conv2d_126)相乘,就完成了整个过程
实现代码:
from keras import Input
from keras.layers import (
    Activation,
    Conv2D,
    Dropout,
    MaxPooling2D,
    UpSampling2D,
    add,
    concatenate,
    multiply,
)
from keras.models import Model
from keras.utils import plot_model

IMG_CHANNEL = 3


def AttnBlock2D(x, g, inter_channel, data_format='channels_first'):
    """Apply an additive attention gate to the skip connection ``x``.

    ``x`` and ``g`` are each projected to ``inter_channel`` channels by a 1x1
    convolution, added, passed through ReLU, collapsed to one channel by a
    second 1x1 convolution and squashed with a sigmoid. The resulting
    single-channel weight map then multiplies ``x``.

    Args:
        x: Skip-connection feature map from the encoder.
        g: Gating signal (the upsampled decoder feature map). It must have
            the same spatial size as ``x`` so the two projections can be added.
        inter_channel: Number of filters of the two 1x1 projections.
        data_format: ``'channels_first'`` or ``'channels_last'``.

    Returns:
        ``x`` weighted by the attention map.
    """
    theta_x = Conv2D(inter_channel, (1, 1), strides=(1, 1), data_format=data_format)(x)
    phi_g = Conv2D(inter_channel, (1, 1), strides=(1, 1), data_format=data_format)(g)
    f = Activation('relu')(add([theta_x, phi_g]))
    psi_f = Conv2D(1, (1, 1), strides=(1, 1), data_format=data_format)(f)
    rate = Activation('sigmoid')(psi_f)
    # ``rate`` has a single channel; the Multiply layer broadcasts it over
    # every channel of ``x``.
    att_x = multiply([x, rate])
    return att_x


def attention_up_and_concate(down_layer, layer, data_format='channels_first'):
    """Upsample ``down_layer``, gate ``layer`` with it, and concatenate both.

    Args:
        down_layer: Feature map coming up from the deeper level; it is
            upsampled 2x here and also used as the gating signal.
        layer: Encoder skip-connection feature map at the target resolution.
        data_format: ``'channels_first'`` or ``'channels_last'``.

    Returns:
        The upsampled ``down_layer`` and the attention-gated ``layer``
        concatenated (in that order) along the channel axis.
    """
    channel_axis = 1 if data_format == 'channels_first' else 3
    in_channel = down_layer.get_shape().as_list()[channel_axis]
    up = UpSampling2D(size=(2, 2), data_format=data_format)(down_layer)
    # Gate projections use a quarter of the deeper level's channel count.
    layer = AttnBlock2D(x=layer, g=up, inter_channel=in_channel // 4, data_format=data_format)
    # Built-in merge layer instead of Lambda(K.concatenate(...)): same result,
    # but no Python closure inside the model, so it saves/loads cleanly.
    return concatenate([up, layer], axis=channel_axis)


# Attention U-Net
def att_unet(img_w, img_h, n_label, data_format='channels_first', img_channel=IMG_CHANNEL):
    """Build a 4-level Attention U-Net with sigmoid outputs.

    Args:
        img_w: Input image width.
        img_h: Input image height.
        n_label: Number of output channels (one sigmoid map per label).
        data_format: ``'channels_first'`` or ``'channels_last'``; applied to
            the input shape and to every layer.
        img_channel: Number of input image channels.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    # Keras 2-D image tensors are ordered (rows, cols) = (height, width).
    if data_format == 'channels_first':
        inputs = Input((img_channel, img_h, img_w))  # e.g. (3, 160, 160)
    else:
        inputs = Input((img_h, img_w, img_channel))
    x = inputs
    depth = 4
    features = 32
    skips = []

    # ENCODER: conv -> dropout -> conv, keep the result as a skip, then pool.
    for _ in range(depth):
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
        skips.append(x)
        # Pool with the same data_format as the convolutions.
        x = MaxPooling2D((2, 2), data_format=data_format)(x)
        features *= 2

    # BOTTLENECK
    x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
    x = Dropout(0.2)(x)
    x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)

    # DECODER: consume the skips from the deepest to the shallowest level.
    for skip in reversed(skips):
        features //= 2
        x = attention_up_and_concate(x, skip, data_format=data_format)
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)

    conv6 = Conv2D(n_label, (1, 1), padding='same', data_format=data_format)(x)
    conv7 = Activation('sigmoid')(conv6)
    model = Model(inputs=inputs, outputs=conv7)
    return model


IMG_WIDTH = 160
IMG_HEIGHT = 160
model = att_unet(IMG_WIDTH, IMG_HEIGHT, n_label=1)
model.summary()
# NOTE: plot_model relies on pydot/graphviz being installed (see Keras docs).
plot_model(model, to_file='Att_U_Net.png', show_shapes=True)
输出:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | Model: "model" __________________________________________________________________________________________________ Layer ( type ) Output Shape Param # Connected to = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = input_11 (InputLayer) [( None , 3 , 160 , 160 ) 0 __________________________________________________________________________________________________ conv2d_119 (Conv2D) ( None , 32 , 160 , 160 ) 896 input_11[ 0 ][ 0 ] __________________________________________________________________________________________________ dropout_45 (Dropout) ( None , 32 , 160 , 160 ) 0 conv2d_119[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_120 (Conv2D) ( None , 32 , 160 , 160 ) 9248 dropout_45[ 0 ][ 0 ] __________________________________________________________________________________________________ max_pooling2d_32 (MaxPooling2D) ( None , 32 , 80 , 80 ) 0 conv2d_120[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_121 (Conv2D) ( None , 64 , 80 , 80 ) 18496 max_pooling2d_32[ 0 ][ 0 ] __________________________________________________________________________________________________ dropout_46 (Dropout) ( None , 64 , 80 , 80 ) 0 conv2d_121[ 0 ][ 0 ] 
__________________________________________________________________________________________________ conv2d_122 (Conv2D) ( None , 64 , 80 , 80 ) 36928 dropout_46[ 0 ][ 0 ] __________________________________________________________________________________________________ max_pooling2d_33 (MaxPooling2D) ( None , 64 , 40 , 40 ) 0 conv2d_122[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_123 (Conv2D) ( None , 128 , 40 , 40 ) 73856 max_pooling2d_33[ 0 ][ 0 ] __________________________________________________________________________________________________ dropout_47 (Dropout) ( None , 128 , 40 , 40 ) 0 conv2d_123[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_124 (Conv2D) ( None , 128 , 40 , 40 ) 147584 dropout_47[ 0 ][ 0 ] __________________________________________________________________________________________________ max_pooling2d_34 (MaxPooling2D) ( None , 128 , 20 , 20 ) 0 conv2d_124[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_125 (Conv2D) ( None , 256 , 20 , 20 ) 295168 max_pooling2d_34[ 0 ][ 0 ] __________________________________________________________________________________________________ dropout_48 (Dropout) ( None , 256 , 20 , 20 ) 0 conv2d_125[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_126 (Conv2D) ( None , 256 , 20 , 20 ) 590080 dropout_48[ 0 ][ 0 ] __________________________________________________________________________________________________ max_pooling2d_35 (MaxPooling2D) ( None , 256 , 10 , 10 ) 0 conv2d_126[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_127 (Conv2D) ( None , 512 , 10 , 10 ) 1180160 max_pooling2d_35[ 0 ][ 0 ] 
__________________________________________________________________________________________________ dropout_49 (Dropout) ( None , 512 , 10 , 10 ) 0 conv2d_127[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_128 (Conv2D) ( None , 512 , 10 , 10 ) 2359808 dropout_49[ 0 ][ 0 ] __________________________________________________________________________________________________ up_sampling2d_11 (UpSampling2D) ( None , 512 , 20 , 20 ) 0 conv2d_128[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_129 (Conv2D) ( None , 128 , 20 , 20 ) 32896 conv2d_126[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_130 (Conv2D) ( None , 128 , 20 , 20 ) 65664 up_sampling2d_11[ 0 ][ 0 ] __________________________________________________________________________________________________ add_6 (Add) ( None , 128 , 20 , 20 ) 0 conv2d_129[ 0 ][ 0 ] conv2d_130[ 0 ][ 0 ] __________________________________________________________________________________________________ activation_18 (Activation) ( None , 128 , 20 , 20 ) 0 add_6[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_131 (Conv2D) ( None , 1 , 20 , 20 ) 129 activation_18[ 0 ][ 0 ] __________________________________________________________________________________________________ activation_19 (Activation) ( None , 1 , 20 , 20 ) 0 conv2d_131[ 0 ][ 0 ] __________________________________________________________________________________________________ multiply_6 (Multiply) ( None , 256 , 20 , 20 ) 0 conv2d_126[ 0 ][ 0 ] activation_19[ 0 ][ 0 ] __________________________________________________________________________________________________ lambda_5 (Lambda) ( None , 768 , 20 , 20 ) 0 up_sampling2d_11[ 0 ][ 0 ] multiply_6[ 0 ][ 0 ] 
__________________________________________________________________________________________________ conv2d_132 (Conv2D) ( None , 256 , 20 , 20 ) 1769728 lambda_5[ 0 ][ 0 ] __________________________________________________________________________________________________ dropout_50 (Dropout) ( None , 256 , 20 , 20 ) 0 conv2d_132[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_133 (Conv2D) ( None , 256 , 20 , 20 ) 590080 dropout_50[ 0 ][ 0 ] __________________________________________________________________________________________________ up_sampling2d_12 (UpSampling2D) ( None , 256 , 40 , 40 ) 0 conv2d_133[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_134 (Conv2D) ( None , 64 , 40 , 40 ) 8256 conv2d_124[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_135 (Conv2D) ( None , 64 , 40 , 40 ) 16448 up_sampling2d_12[ 0 ][ 0 ] __________________________________________________________________________________________________ add_7 (Add) ( None , 64 , 40 , 40 ) 0 conv2d_134[ 0 ][ 0 ] conv2d_135[ 0 ][ 0 ] __________________________________________________________________________________________________ activation_20 (Activation) ( None , 64 , 40 , 40 ) 0 add_7[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_136 (Conv2D) ( None , 1 , 40 , 40 ) 65 activation_20[ 0 ][ 0 ] __________________________________________________________________________________________________ activation_21 (Activation) ( None , 1 , 40 , 40 ) 0 conv2d_136[ 0 ][ 0 ] __________________________________________________________________________________________________ multiply_7 (Multiply) ( None , 128 , 40 , 40 ) 0 conv2d_124[ 0 ][ 0 ] activation_21[ 0 ][ 0 ] 
__________________________________________________________________________________________________ lambda_6 (Lambda) ( None , 384 , 40 , 40 ) 0 up_sampling2d_12[ 0 ][ 0 ] multiply_7[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_137 (Conv2D) ( None , 128 , 40 , 40 ) 442496 lambda_6[ 0 ][ 0 ] __________________________________________________________________________________________________ dropout_51 (Dropout) ( None , 128 , 40 , 40 ) 0 conv2d_137[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_138 (Conv2D) ( None , 128 , 40 , 40 ) 147584 dropout_51[ 0 ][ 0 ] __________________________________________________________________________________________________ up_sampling2d_13 (UpSampling2D) ( None , 128 , 80 , 80 ) 0 conv2d_138[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_139 (Conv2D) ( None , 32 , 80 , 80 ) 2080 conv2d_122[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_140 (Conv2D) ( None , 32 , 80 , 80 ) 4128 up_sampling2d_13[ 0 ][ 0 ] __________________________________________________________________________________________________ add_8 (Add) ( None , 32 , 80 , 80 ) 0 conv2d_139[ 0 ][ 0 ] conv2d_140[ 0 ][ 0 ] __________________________________________________________________________________________________ activation_22 (Activation) ( None , 32 , 80 , 80 ) 0 add_8[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_141 (Conv2D) ( None , 1 , 80 , 80 ) 33 activation_22[ 0 ][ 0 ] __________________________________________________________________________________________________ activation_23 (Activation) ( None , 1 , 80 , 80 ) 0 conv2d_141[ 0 ][ 0 ] 
__________________________________________________________________________________________________ multiply_8 (Multiply) ( None , 64 , 80 , 80 ) 0 conv2d_122[ 0 ][ 0 ] activation_23[ 0 ][ 0 ] __________________________________________________________________________________________________ lambda_7 (Lambda) ( None , 192 , 80 , 80 ) 0 up_sampling2d_13[ 0 ][ 0 ] multiply_8[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_142 (Conv2D) ( None , 64 , 80 , 80 ) 110656 lambda_7[ 0 ][ 0 ] __________________________________________________________________________________________________ dropout_52 (Dropout) ( None , 64 , 80 , 80 ) 0 conv2d_142[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_143 (Conv2D) ( None , 64 , 80 , 80 ) 36928 dropout_52[ 0 ][ 0 ] __________________________________________________________________________________________________ up_sampling2d_14 (UpSampling2D) ( None , 64 , 160 , 160 ) 0 conv2d_143[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_144 (Conv2D) ( None , 16 , 160 , 160 ) 528 conv2d_120[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_145 (Conv2D) ( None , 16 , 160 , 160 ) 1040 up_sampling2d_14[ 0 ][ 0 ] __________________________________________________________________________________________________ add_9 (Add) ( None , 16 , 160 , 160 ) 0 conv2d_144[ 0 ][ 0 ] conv2d_145[ 0 ][ 0 ] __________________________________________________________________________________________________ activation_24 (Activation) ( None , 16 , 160 , 160 ) 0 add_9[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_146 (Conv2D) ( None , 1 , 160 , 160 ) 17 activation_24[ 0 ][ 0 ] 
__________________________________________________________________________________________________ activation_25 (Activation) ( None , 1 , 160 , 160 ) 0 conv2d_146[ 0 ][ 0 ] __________________________________________________________________________________________________ multiply_9 (Multiply) ( None , 32 , 160 , 160 ) 0 conv2d_120[ 0 ][ 0 ] activation_25[ 0 ][ 0 ] __________________________________________________________________________________________________ lambda_8 (Lambda) ( None , 96 , 160 , 160 ) 0 up_sampling2d_14[ 0 ][ 0 ] multiply_9[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_147 (Conv2D) ( None , 32 , 160 , 160 ) 27680 lambda_8[ 0 ][ 0 ] __________________________________________________________________________________________________ dropout_53 (Dropout) ( None , 32 , 160 , 160 ) 0 conv2d_147[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_148 (Conv2D) ( None , 32 , 160 , 160 ) 9248 dropout_53[ 0 ][ 0 ] __________________________________________________________________________________________________ conv2d_149 (Conv2D) ( None , 1 , 160 , 160 ) 33 conv2d_148[ 0 ][ 0 ] __________________________________________________________________________________________________ activation_26 (Activation) ( None , 1 , 160 , 160 ) 0 conv2d_149[ 0 ][ 0 ] = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = Total params: 7 , 977 , 941 Trainable params: 7 , 977 , 941 Non - trainable params: 0 __________________________________________________________________________________________________ |
结构图如下:
针对通道在最后(channels_last)的情况,代码补充如下:
from keras import Input
from keras.layers import (
    Activation,
    Conv2D,
    Dropout,
    MaxPooling2D,
    UpSampling2D,
    add,
    concatenate,
    multiply,
)
from keras.models import Model

IMG_CHANNEL = 3


def AttnBlock2D(x, g, inter_channel):
    """Apply an additive attention gate to ``x`` (channels_last tensors).

    Args:
        x: Skip-connection feature map from the encoder.
        g: The deeper (down) layer after upsampling; same spatial size as ``x``.
        inter_channel: Filters of the two 1x1 projections (the caller passes
            the down layer's channel count // 4).

    Returns:
        ``x`` multiplied by a single-channel sigmoid attention map.
    """
    theta_x = Conv2D(inter_channel, (1, 1), strides=(1, 1))(x)
    phi_g = Conv2D(inter_channel, (1, 1), strides=(1, 1))(g)
    f = Activation('relu')(add([theta_x, phi_g]))
    psi_f = Conv2D(1, (1, 1), strides=(1, 1))(f)
    rate = Activation('sigmoid')(psi_f)
    # The single-channel ``rate`` is broadcast over all channels of ``x``.
    att_x = multiply([x, rate])
    return att_x


def attention_up_and_concate(down_layer, layer):
    """Upsample ``down_layer``, gate ``layer`` with it, and concatenate both.

    Args:
        down_layer: Feature map coming up from the deeper level.
        layer: Encoder skip-connection feature map.

    Returns:
        The upsampled ``down_layer`` and the gated ``layer`` concatenated on
        the last (channel) axis.
    """
    in_channel = down_layer.get_shape().as_list()[3]
    up = UpSampling2D(size=(2, 2))(down_layer)
    layer = AttnBlock2D(x=layer, g=up, inter_channel=in_channel // 4)
    # Built-in merge layer instead of Lambda(K.concatenate(...)): same result,
    # but serializable with the model.
    return concatenate([up, layer], axis=3)


# Attention U-Net
def att_unet(img_w, img_h, n_label, img_channel=IMG_CHANNEL):
    """Build a 4-level channels_last Attention U-Net with sigmoid outputs.

    Args:
        img_w: Input image width.
        img_h: Input image height.
        n_label: Number of output channels (one sigmoid map per label).
        img_channel: Number of input image channels.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    # Keras channels_last inputs are (rows, cols, channels) = (H, W, C).
    inputs = Input((img_h, img_w, img_channel))
    x = inputs
    depth = 4
    features = 32
    skips = []

    # ENCODER
    for _ in range(depth):
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)
        features *= 2

    # BOTTLENECK
    x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
    x = Dropout(0.2)(x)
    x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)

    # DECODER: consume the skips from the deepest to the shallowest level.
    for skip in reversed(skips):
        features //= 2
        x = attention_up_and_concate(x, skip)
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)

    conv6 = Conv2D(n_label, (1, 1), padding='same')(x)
    conv7 = Activation('sigmoid')(conv6)
    model = Model(inputs=inputs, outputs=conv7)
    return model


IMG_WIDTH = 160
IMG_HEIGHT = 160
model = att_unet(IMG_WIDTH, IMG_HEIGHT, n_label=1)
model.summary()
posted on 2021-07-06 20:55 McDelfino 阅读(2141) 评论(0) 编辑 收藏 举报
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步