Loading [Contrib]/a11y/accessibility-menu.js

alex_bn_lee

导航

【600】Attention U-Net 解释

参考:Attention-UNet for Pneumothorax Segmentation 

参考:Attention U-Net


一、Model 结构图

  说明:这是 3D 数据的结构图,F 代表 feature(即 channel),H 代表 height,W 代表 width,D 代表 depth,即 3D 数据块的深度。对于普通的 2D 图片数据可以去掉 D;另外通常会把通道放在最后(channels_last),因此可以表示为 $H_1 \times W_1 \times F_1$。

二、AttnBlock2D 函数的图示

  下图为 AttnBlock2D 函数的实现效果:它的输出相当于经过注意力加权后的 U-Net skip connection layer,之后还需要与上采样后的 layer 做一次 Concatenation。

  以上为 Attention Gate 的原始结构图,可以按照下面的结构图进行理解:

  • 输入为 $x$(最上方的 conv2d_126,分成两条线路)和 $g$(左边的 up_sampling2d_11)

  • $x$ 经过一个卷积、$g$ 经过一个卷积,然后两者做个加法

    • $x$ 经过一个 1×1 卷积,输出通道数为 g.channels // 4(即 down layer 的通道数 // 4,等于 x.channels // 2;例如 512 // 4 = 128)
    • $g$ 经过一个 1×1 卷积,输出通道数同样为 g.channels // 4
  • 之后依次经过 ReLU、1×1 卷积、Sigmoid,得到注意力权重图,如下图的 activation_19

    • 该卷积的输出通道数为 1,因此得到的权重图可以广播到 $x$ 的所有通道上进行相乘,实现 Attention
  • 最后将 activation_19 与 $x$(最上 conv2d_126) 进行相乘,就完成了整个过程

  实现代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from keras import Input
from keras.layers import Conv2D, Activation, UpSampling2D, Lambda, Dropout, MaxPooling2D, multiply, add
from keras import backend as K
from keras.models import Model
 
# Number of channels in each input image (3 -- presumably RGB).
IMG_CHANNEL: int = 3
 
def AttnBlock2D(x, g, inter_channel, data_format='channels_first'):
    """Additive attention gate used on an Attention U-Net skip connection.

    ``x`` and ``g`` are each projected to ``inter_channel`` channels with a
    1x1 convolution and summed. The sum goes through ReLU, a 1x1 convolution
    down to a single channel, and a sigmoid. The resulting one-channel map
    rescales ``x``.

    Args:
        x: skip-connection feature map to be gated.
        g: gating feature map (the upsampled decoder tensor). Its spatial size
            must equal that of ``x`` so the two projections can be added.
        inter_channel: number of filters of the two intermediate 1x1 projections.
        data_format: 'channels_first' or 'channels_last', forwarded to Conv2D.

    Returns:
        ``x`` multiplied element-wise by the single-channel attention map.
    """
    def conv1x1(filters, tensor):
        # Every convolution in the gate is a stride-1, 1x1 projection.
        return Conv2D(filters, [1, 1], strides=[1, 1], data_format=data_format)(tensor)

    gate = add([conv1x1(inter_channel, x), conv1x1(inter_channel, g)])
    gate = Activation('relu')(gate)
    gate = conv1x1(1, gate)
    gate = Activation('sigmoid')(gate)
    return multiply([x, gate])
 
 
def attention_up_and_concate(down_layer, layer, data_format='channels_first'):
    """Upsample the decoder tensor, gate the skip connection with it, and concatenate.

    Args:
        down_layer: the deeper decoder tensor coming up from below. After 2x
            upsampling it is used as the gating signal ``g``.
        layer: the encoder skip-connection tensor. Its spatial size must be
            twice that of ``down_layer``.
        data_format: 'channels_first' (channel axis 1); any other value is
            treated as channels_last (channel axis -1).

    Returns:
        The upsampled ``down_layer`` and the attention-gated ``layer``,
        concatenated along the channel axis (upsampled tensor first).
    """
    # Local import so this block does not rely on the file's import list.
    from keras.layers import Concatenate

    channel_axis = 1 if data_format == 'channels_first' else -1
    # `.shape` works on Keras tensors across TF-Keras and Keras 3, unlike the
    # TF-specific `get_shape().as_list()`; int() also unwraps TF1 Dimensions.
    in_channel = int(down_layer.shape[channel_axis])

    up = UpSampling2D(size=(2, 2), data_format=data_format)(down_layer)
    # The gate's intermediate width is a quarter of the gating tensor's channels.
    gated = AttnBlock2D(x=layer, g=up, inter_channel=in_channel // 4, data_format=data_format)

    # Built-in Concatenate instead of Lambda(K.concatenate): no Python closure
    # in the graph, and no dependence on backend functions missing in Keras 3.
    return Concatenate(axis=channel_axis)([up, gated])
 
# Attention U-Net
def att_unet(img_w, img_h, n_label, data_format='channels_first'):
    """Build a depth-4 Attention U-Net with a per-pixel sigmoid output.

    Encoder levels: Conv3x3-ReLU, Dropout(0.2), Conv3x3-ReLU, MaxPool 2x2. The
    first level has 32 filters and each level doubles it (512 at the
    bottleneck). Each decoder level halves the filters, upsamples, gates the
    matching skip connection, concatenates, then applies Conv3x3-ReLU,
    Dropout(0.2), Conv3x3-ReLU.

    Args:
        img_w: input width in pixels.
        img_h: input height in pixels.
        n_label: number of output channels, each passed through a sigmoid.
        data_format: 'channels_first' -> input shape (IMG_CHANNEL, img_h, img_w);
            any other value -> channels_last input (img_h, img_w, IMG_CHANNEL).

    Returns:
        An uncompiled keras ``Model``.

    Raises:
        ValueError: if a given spatial size is not divisible by 2**depth (16).
            Otherwise the pooled and upsampled feature maps would not line up.
    """
    depth = 4
    features = 32

    # Each encoder level halves H and W and each decoder level doubles them,
    # so the sizes only round-trip when divisible by 2**depth.
    factor = 2 ** depth
    for name, size in (('img_w', img_w), ('img_h', img_h)):
        if size is not None and size % factor != 0:
            raise ValueError('{}={} must be divisible by {} for a depth-{} U-Net'
                             .format(name, size, factor, depth))

    # Keras expects (C, H, W) for channels_first and (H, W, C) for channels_last.
    if data_format == 'channels_first':
        inputs = Input((IMG_CHANNEL, img_h, img_w))  # e.g. (3, 160, 160)
    else:
        inputs = Input((img_h, img_w, IMG_CHANNEL))

    x = inputs
    skips = []

    # ENCODER
    for _ in range(depth):
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
        skips.append(x)
        # Honour data_format here too (it was hard-coded to 'channels_first').
        x = MaxPooling2D((2, 2), data_format=data_format)(x)
        features *= 2

    # BOTTLENECK
    x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
    x = Dropout(0.2)(x)
    x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)

    # DECODER: consume the skip connections deepest-first.
    for skip in reversed(skips):
        features //= 2
        x = attention_up_and_concate(x, skip, data_format=data_format)
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)

    logits = Conv2D(n_label, (1, 1), padding='same', data_format=data_format)(x)
    outputs = Activation('sigmoid')(logits)

    return Model(inputs=inputs, outputs=outputs)
 
IMG_WIDTH = 160
IMG_HEIGHT = 160

# Build the default (channels_first) model and print its layer table.
model = att_unet(IMG_WIDTH, IMG_HEIGHT, n_label=1)
model.summary()

# `keras.utils.vis_utils` is not available in newer Keras releases; `plot_model`
# is exported directly from `keras.utils`. (plot_model needs pydot + graphviz.)
from keras.utils import plot_model
plot_model(model, to_file='Att_U_Net.png', show_shapes=True)

   输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                    
==================================================================================================
input_11 (InputLayer)           [(None, 3, 160, 160) 0                                           
__________________________________________________________________________________________________
conv2d_119 (Conv2D)             (None, 32, 160, 160) 896         input_11[0][0]                  
__________________________________________________________________________________________________
dropout_45 (Dropout)            (None, 32, 160, 160) 0           conv2d_119[0][0]                
__________________________________________________________________________________________________
conv2d_120 (Conv2D)             (None, 32, 160, 160) 9248        dropout_45[0][0]                
__________________________________________________________________________________________________
max_pooling2d_32 (MaxPooling2D) (None, 32, 80, 80)   0           conv2d_120[0][0]                
__________________________________________________________________________________________________
conv2d_121 (Conv2D)             (None, 64, 80, 80)   18496       max_pooling2d_32[0][0]          
__________________________________________________________________________________________________
dropout_46 (Dropout)            (None, 64, 80, 80)   0           conv2d_121[0][0]                
__________________________________________________________________________________________________
conv2d_122 (Conv2D)             (None, 64, 80, 80)   36928       dropout_46[0][0]                
__________________________________________________________________________________________________
max_pooling2d_33 (MaxPooling2D) (None, 64, 40, 40)   0           conv2d_122[0][0]                
__________________________________________________________________________________________________
conv2d_123 (Conv2D)             (None, 128, 40, 4073856       max_pooling2d_33[0][0]          
__________________________________________________________________________________________________
dropout_47 (Dropout)            (None, 128, 40, 400           conv2d_123[0][0]                
__________________________________________________________________________________________________
conv2d_124 (Conv2D)             (None, 128, 40, 40147584      dropout_47[0][0]                
__________________________________________________________________________________________________
max_pooling2d_34 (MaxPooling2D) (None, 128, 20, 200           conv2d_124[0][0]                
__________________________________________________________________________________________________
conv2d_125 (Conv2D)             (None, 256, 20, 20295168      max_pooling2d_34[0][0]          
__________________________________________________________________________________________________
dropout_48 (Dropout)            (None, 256, 20, 200           conv2d_125[0][0]                
__________________________________________________________________________________________________
conv2d_126 (Conv2D)             (None, 256, 20, 20590080      dropout_48[0][0]                
__________________________________________________________________________________________________
max_pooling2d_35 (MaxPooling2D) (None, 256, 10, 100           conv2d_126[0][0]                
__________________________________________________________________________________________________
conv2d_127 (Conv2D)             (None, 512, 10, 101180160     max_pooling2d_35[0][0]          
__________________________________________________________________________________________________
dropout_49 (Dropout)            (None, 512, 10, 100           conv2d_127[0][0]                
__________________________________________________________________________________________________
conv2d_128 (Conv2D)             (None, 512, 10, 102359808     dropout_49[0][0]                
__________________________________________________________________________________________________
up_sampling2d_11 (UpSampling2D) (None, 512, 20, 200           conv2d_128[0][0]                
__________________________________________________________________________________________________
conv2d_129 (Conv2D)             (None, 128, 20, 2032896       conv2d_126[0][0]                
__________________________________________________________________________________________________
conv2d_130 (Conv2D)             (None, 128, 20, 2065664       up_sampling2d_11[0][0]          
__________________________________________________________________________________________________
add_6 (Add)                     (None, 128, 20, 200           conv2d_129[0][0]                
                                                                 conv2d_130[0][0]                
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 128, 20, 200           add_6[0][0]                     
__________________________________________________________________________________________________
conv2d_131 (Conv2D)             (None, 1, 20, 20)    129         activation_18[0][0]             
__________________________________________________________________________________________________
activation_19 (Activation)      (None, 1, 20, 20)    0           conv2d_131[0][0]                
__________________________________________________________________________________________________
multiply_6 (Multiply)           (None, 256, 20, 200           conv2d_126[0][0]                
                                                                 activation_19[0][0]             
__________________________________________________________________________________________________
lambda_5 (Lambda)               (None, 768, 20, 200           up_sampling2d_11[0][0]          
                                                                 multiply_6[0][0]                
__________________________________________________________________________________________________
conv2d_132 (Conv2D)             (None, 256, 20, 201769728     lambda_5[0][0]                  
__________________________________________________________________________________________________
dropout_50 (Dropout)            (None, 256, 20, 200           conv2d_132[0][0]                
__________________________________________________________________________________________________
conv2d_133 (Conv2D)             (None, 256, 20, 20590080      dropout_50[0][0]                
__________________________________________________________________________________________________
up_sampling2d_12 (UpSampling2D) (None, 256, 40, 400           conv2d_133[0][0]                
__________________________________________________________________________________________________
conv2d_134 (Conv2D)             (None, 64, 40, 40)   8256        conv2d_124[0][0]                
__________________________________________________________________________________________________
conv2d_135 (Conv2D)             (None, 64, 40, 40)   16448       up_sampling2d_12[0][0]          
__________________________________________________________________________________________________
add_7 (Add)                     (None, 64, 40, 40)   0           conv2d_134[0][0]                
                                                                 conv2d_135[0][0]                
__________________________________________________________________________________________________
activation_20 (Activation)      (None, 64, 40, 40)   0           add_7[0][0]                     
__________________________________________________________________________________________________
conv2d_136 (Conv2D)             (None, 1, 40, 40)    65          activation_20[0][0]             
__________________________________________________________________________________________________
activation_21 (Activation)      (None, 1, 40, 40)    0           conv2d_136[0][0]                
__________________________________________________________________________________________________
multiply_7 (Multiply)           (None, 128, 40, 400           conv2d_124[0][0]                
                                                                 activation_21[0][0]             
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, 384, 40, 400           up_sampling2d_12[0][0]          
                                                                 multiply_7[0][0]                
__________________________________________________________________________________________________
conv2d_137 (Conv2D)             (None, 128, 40, 40442496      lambda_6[0][0]                  
__________________________________________________________________________________________________
dropout_51 (Dropout)            (None, 128, 40, 400           conv2d_137[0][0]                
__________________________________________________________________________________________________
conv2d_138 (Conv2D)             (None, 128, 40, 40147584      dropout_51[0][0]                
__________________________________________________________________________________________________
up_sampling2d_13 (UpSampling2D) (None, 128, 80, 800           conv2d_138[0][0]                
__________________________________________________________________________________________________
conv2d_139 (Conv2D)             (None, 32, 80, 80)   2080        conv2d_122[0][0]                
__________________________________________________________________________________________________
conv2d_140 (Conv2D)             (None, 32, 80, 80)   4128        up_sampling2d_13[0][0]          
__________________________________________________________________________________________________
add_8 (Add)                     (None, 32, 80, 80)   0           conv2d_139[0][0]                
                                                                 conv2d_140[0][0]                
__________________________________________________________________________________________________
activation_22 (Activation)      (None, 32, 80, 80)   0           add_8[0][0]                     
__________________________________________________________________________________________________
conv2d_141 (Conv2D)             (None, 1, 80, 80)    33          activation_22[0][0]             
__________________________________________________________________________________________________
activation_23 (Activation)      (None, 1, 80, 80)    0           conv2d_141[0][0]                
__________________________________________________________________________________________________
multiply_8 (Multiply)           (None, 64, 80, 80)   0           conv2d_122[0][0]                
                                                                 activation_23[0][0]             
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, 192, 80, 800           up_sampling2d_13[0][0]          
                                                                 multiply_8[0][0]                
__________________________________________________________________________________________________
conv2d_142 (Conv2D)             (None, 64, 80, 80)   110656      lambda_7[0][0]                  
__________________________________________________________________________________________________
dropout_52 (Dropout)            (None, 64, 80, 80)   0           conv2d_142[0][0]                
__________________________________________________________________________________________________
conv2d_143 (Conv2D)             (None, 64, 80, 80)   36928       dropout_52[0][0]                
__________________________________________________________________________________________________
up_sampling2d_14 (UpSampling2D) (None, 64, 160, 160) 0           conv2d_143[0][0]                
__________________________________________________________________________________________________
conv2d_144 (Conv2D)             (None, 16, 160, 160) 528         conv2d_120[0][0]                
__________________________________________________________________________________________________
conv2d_145 (Conv2D)             (None, 16, 160, 160) 1040        up_sampling2d_14[0][0]          
__________________________________________________________________________________________________
add_9 (Add)                     (None, 16, 160, 160) 0           conv2d_144[0][0]                
                                                                 conv2d_145[0][0]                
__________________________________________________________________________________________________
activation_24 (Activation)      (None, 16, 160, 160) 0           add_9[0][0]                     
__________________________________________________________________________________________________
conv2d_146 (Conv2D)             (None, 1, 160, 16017          activation_24[0][0]             
__________________________________________________________________________________________________
activation_25 (Activation)      (None, 1, 160, 1600           conv2d_146[0][0]                
__________________________________________________________________________________________________
multiply_9 (Multiply)           (None, 32, 160, 160) 0           conv2d_120[0][0]                
                                                                 activation_25[0][0]             
__________________________________________________________________________________________________
lambda_8 (Lambda)               (None, 96, 160, 160) 0           up_sampling2d_14[0][0]          
                                                                 multiply_9[0][0]                
__________________________________________________________________________________________________
conv2d_147 (Conv2D)             (None, 32, 160, 160) 27680       lambda_8[0][0]                  
__________________________________________________________________________________________________
dropout_53 (Dropout)            (None, 32, 160, 160) 0           conv2d_147[0][0]                
__________________________________________________________________________________________________
conv2d_148 (Conv2D)             (None, 32, 160, 160) 9248        dropout_53[0][0]                
__________________________________________________________________________________________________
conv2d_149 (Conv2D)             (None, 1, 160, 16033          conv2d_148[0][0]                
__________________________________________________________________________________________________
activation_26 (Activation)      (None, 1, 160, 1600           conv2d_149[0][0]                
==================================================================================================
Total params: 7,977,941
Trainable params: 7,977,941
Non-trainable params: 0
__________________________________________________________________________________________________

   结构图如下:

  补充:通道放在最后(channels_last)时的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from keras import Input
from keras.layers import Conv2D, Activation, UpSampling2D, Lambda, Dropout, MaxPooling2D, multiply, add
from keras import backend as K
from keras.models import Model
 
# Number of channels in each input image (3 -- presumably RGB).
IMG_CHANNEL: int = 3
 
def AttnBlock2D(x, g, inter_channel):
    """Additive attention gate for a skip connection (Conv2D default data_format).

    Args:
        x: the skip-connection layer to be gated.
        g: the upsampled down layer, used as the gating signal. It must have the
            same spatial size as ``x`` so the projections can be added.
        inter_channel: width of the two 1x1 projections (the caller passes the
            down layer's channel count // 4).

    Returns:
        ``x`` scaled element-wise by a sigmoid attention map with one channel.
    """
    # Project both inputs to the same intermediate width, then merge them.
    projected = [Conv2D(inter_channel, [1, 1], strides=[1, 1])(t) for t in (x, g)]
    weights = add(projected)

    # ReLU -> 1-channel 1x1 conv -> sigmoid turns the merged map into weights.
    for layer in (Activation('relu'),
                  Conv2D(1, [1, 1], strides=[1, 1]),
                  Activation('sigmoid')):
        weights = layer(weights)

    return multiply([x, weights])
 
def attention_up_and_concate(down_layer, layer):
    """Upsample ``down_layer``, gate ``layer`` with it, and concatenate (channels last).

    Args:
        down_layer: the deeper decoder tensor coming up from below. After 2x
            upsampling it becomes the gating signal.
        layer: the encoder skip-connection tensor. Its spatial size must be
            twice that of ``down_layer``.

    Returns:
        The upsampled ``down_layer`` and the attention-gated ``layer``,
        concatenated along the last (channel) axis, upsampled tensor first.
    """
    # Local import so this block does not rely on the file's import list.
    from keras.layers import Concatenate

    # `.shape` works on Keras tensors across TF-Keras and Keras 3, unlike the
    # TF-specific `get_shape().as_list()`; int() also unwraps TF1 Dimensions.
    in_channel = int(down_layer.shape[-1])

    up = UpSampling2D(size=(2, 2))(down_layer)
    # The gate's intermediate width is a quarter of the gating tensor's channels.
    gated = AttnBlock2D(x=layer, g=up, inter_channel=in_channel // 4)

    # Built-in Concatenate instead of Lambda(K.concatenate): no Python closure
    # in the graph, and no dependence on backend functions missing in Keras 3.
    return Concatenate(axis=-1)([up, gated])
 
# Attention U-Net
def att_unet(img_w, img_h, n_label):
    """Build a depth-4 channels-last Attention U-Net with a sigmoid output.

    Encoder levels: Conv3x3-ReLU, Dropout(0.2), Conv3x3-ReLU, MaxPool 2x2. The
    first level has 32 filters and each level doubles it (512 at the
    bottleneck). Each decoder level halves the filters, upsamples, gates the
    matching skip connection, concatenates, then applies Conv3x3-ReLU,
    Dropout(0.2), Conv3x3-ReLU.

    Args:
        img_w: input width in pixels.
        img_h: input height in pixels.
        n_label: number of output channels, each passed through a sigmoid.

    Returns:
        An uncompiled keras ``Model`` with input shape (img_h, img_w, IMG_CHANNEL).

    Raises:
        ValueError: if a given spatial size is not divisible by 2**depth (16).
            Otherwise the pooled and upsampled feature maps would not line up.
    """
    depth = 4
    features = 32

    # Each encoder level halves H and W and each decoder level doubles them,
    # so the sizes only round-trip when divisible by 2**depth.
    factor = 2 ** depth
    for name, size in (('img_w', img_w), ('img_h', img_h)):
        if size is not None and size % factor != 0:
            raise ValueError('{}={} must be divisible by {} for a depth-{} U-Net'
                             .format(name, size, factor, depth))

    # Keras channels_last input is (height, width, channels).
    inputs = Input((img_h, img_w, IMG_CHANNEL))
    x = inputs
    skips = []

    # ENCODER
    for _ in range(depth):
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)
        features *= 2

    # BOTTLENECK
    x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
    x = Dropout(0.2)(x)
    x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)

    # DECODER: consume the skip connections deepest-first.
    for skip in reversed(skips):
        features //= 2
        x = attention_up_and_concate(x, skip)
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)

    logits = Conv2D(n_label, (1, 1), padding='same')(x)
    outputs = Activation('sigmoid')(logits)

    return Model(inputs=inputs, outputs=outputs)
 
# Demo: square 160x160 input with a single output label.
IMG_HEIGHT = 160
IMG_WIDTH = 160

model = att_unet(img_w=IMG_WIDTH, img_h=IMG_HEIGHT, n_label=1)
model.summary()

 

posted on   McDelfino  阅读(2141)  评论(0编辑  收藏  举报

努力加载评论中...
点击右上角即可分享
微信分享提示