数据噪声强?尝试一下深度残差收缩网络(附代码)
深度残差收缩网络是一种致力于从强噪声数据中学习特征的深度学习方法,是由“深度残差网络”和“收缩”两个部分组成的。
一方面,“深度残差网络”目前已经成为了深度学习领域的基础网络。
另一方面,“收缩”指的是软阈值化,是许多信号降噪算法的关键步骤。
更重要地,在深度残差收缩网络中,软阈值化所需要的阈值,实质上是在注意力机制下自动设置的,从而避免了人工设置阈值的麻烦。
在本文中,我们首先对残差网络、软阈值化和注意力机制的相关基础进行了简要的回顾,然后对深度残差收缩网络的动机、算法和应用展开解读。
1. 相关基础
1.1 残差网络
残差网络(或称为深度残差网络、深度残差学习,英文ResNet)属于一种卷积神经网络。相较于普通的卷积神经网络,残差网络采用了跨层的恒等连接,以减轻卷积神经网络的训练难度。残差网络的一种常见的基本模块如图1所示。
1.2 软阈值化
软阈值化是许多信号降噪方法的核心步骤。它的用处是将绝对值低于某个阈值的特征置为零,将其他的特征也朝着零进行调整,也就是收缩。在这里,阈值是一个需要预先设置的参数,其取值大小对于降噪的结果有着直接的影响。软阈值化的输入与输出之间的关系如图所示。
可以看出,软阈值化是一种非线性变换,有着与ReLU激活函数相近的性质:梯度要么是0,要么是1。因此,软阈值化也能够作为神经网络的激活函数。事实上,一些神经网络已经将软阈值函数作为激活函数进行了使用。
1.3 注意力机制
注意力机制就是将注意力集中于局部关键信息的机制,可以分成两步:第一,通过全局扫描,发现局部有用信息;第二,增强有用信息并抑制冗余信息。
Squeeze-and-Excitation Network是一种非常经典的注意力机制下的深度学习方法。它可以通过一个小型的子网络,自动学习得到一组权重,对特征图的各个通道进行加权。其含义在于,某些特征通道是较为重要的,而另一些特征通道是信息冗余的;那么,我们就可以通过这种方式增强有用特征通道、削弱冗余特征通道。Squeeze-and-Excitation Network的一种基本模块如下图所示。
值得指出的是,通过这种方式,每个样本都可以有自己独特的一组权重,可以根据样本自身的特点,进行独特的特征通道加权调整。例如,样本A的第一特征通道是重要的,第二特征通道是不重要的;而样本B的第一特征通道是不重要的,第二特征通道是重要的;通过这种方式,样本A可以有自己的一组权重,以加强第一特征通道,削弱第二特征通道;同样地,样本B可以有自己的一组权重,以削弱第一特征通道,加强第二特征通道。
2. 深度残差收缩网络理论
2.1 动机
首先,现实世界中的数据,或多或少都含有一些冗余信息。那么我们就可以尝试将软阈值化嵌入残差网络中,以进行冗余信息的消除。
其次,各个样本中冗余信息含量经常是不同的。那么我们就可以借助注意力机制,根据各个样本的情况,自适应地给各个样本设置不同的阈值。
2.2 算法
与残差网络和Squeeze-and-Excitation Network相似,深度残差收缩网络也是由许多基本模块堆叠而成的。每个基本模块都有一个子网络,用于自动学习得到一组阈值,用于特征图的软阈值化。值得指出的是,通过这种方式,每个样本都有着自己独特的一组阈值。深度残差收缩网络的一种基本模块如下图所示。
深度残差收缩网络的大致框架如下图所示,是由输入层、许多基本模块以及最后的全连接输出层等部分所组成的。
2.3 应用
在原始论文中,深度残差收缩网络是应用于基于振动信号的机械设备故障诊断。但是从原理上来讲,深度残差收缩网络面向的是数据集含有冗余信息的情况,而冗余信息是无处不在的。例如,在图像识别的时候,图像中总会包含一些与标签无关的区域;在语音识别的时候,音频中经常会含有各种形式的噪声。因此,深度残差收缩网络,或者说这种在深度学习算法内部集成“注意力机制”+“软阈值化”的思路,有着较为广泛的研究价值和应用前景。
3.Keras代码示例
1 #!/usr/bin/env python3 2 # -*- coding: utf-8 -*- 3 """ 4 Created on Sat Dec 28 23:24:05 2019 5 Implemented using TensorFlow 1.0.1 and Keras 2.2.1 6 7 M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, 8 IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898 9 @author: super_9527 10 """ 11 12 from __future__ import print_function 13 import keras 14 import numpy as np 15 from keras.datasets import mnist 16 from keras.layers import Dense, Conv2D, BatchNormalization, Activation 17 from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D 18 from keras.optimizers import Adam 19 from keras.regularizers import l2 20 from keras import backend as K 21 from keras.models import Model 22 from keras.layers.core import Lambda 23 K.set_learning_phase(1) 24 25 # Input image dimensions 26 img_rows, img_cols = 28, 28 27 28 # The data, split between train and test sets 29 (x_train, y_train), (x_test, y_test) = mnist.load_data() 30 31 if K.image_data_format() == 'channels_first': 32 x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) 33 x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) 34 input_shape = (1, img_rows, img_cols) 35 else: 36 x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 37 x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 38 input_shape = (img_rows, img_cols, 1) 39 40 # Noised data 41 x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1]) 42 x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1]) 43 print('x_train shape:', x_train.shape) 44 print(x_train.shape[0], 'train samples') 45 print(x_test.shape[0], 'test samples') 46 47 # convert class vectors to binary class matrices 48 y_train = keras.utils.to_categorical(y_train, 10) 49 y_test = keras.utils.to_categorical(y_test, 10) 50 51 52 def abs_backend(inputs): 53 return K.abs(inputs) 54 55 def expand_dim_backend(inputs): 56 return K.expand_dims(K.expand_dims(inputs,1),1) 57 58 def sign_backend(inputs): 59 return K.sign(inputs) 60 61 def pad_backend(inputs, in_channels, out_channels): 62 pad_dim = (out_channels - in_channels)//2 63 inputs = K.expand_dims(inputs,-1) 64 inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last') 65 return K.squeeze(inputs, -1) 66 67 # Residual Shrinakge Block 68 def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False, 69 downsample_strides=2): 70 71 residual = incoming 72 in_channels = incoming.get_shape().as_list()[-1] 73 74 for i in range(nb_blocks): 75 76 identity = residual 77 78 if not downsample: 79 downsample_strides = 1 80 81 residual = BatchNormalization()(residual) 82 residual = Activation('relu')(residual) 83 residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides), 84 padding='same', kernel_initializer='he_normal', 85 kernel_regularizer=l2(1e-4))(residual) 86 87 residual = BatchNormalization()(residual) 88 residual = Activation('relu')(residual) 89 residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal', 90 kernel_regularizer=l2(1e-4))(residual) 91 92 # Calculate global means 93 residual_abs = Lambda(abs_backend)(residual) 94 abs_mean = GlobalAveragePooling2D()(residual_abs) 95 96 # Calculate scaling coefficients 97 scales = Dense(out_channels, activation=None, kernel_initializer='he_normal', 98 kernel_regularizer=l2(1e-4))(abs_mean) 99 scales = BatchNormalization()(scales) 100 scales = Activation('relu')(scales) 101 scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales) 102 scales = Lambda(expand_dim_backend)(scales) 103 104 # Calculate thresholds 105 thres = keras.layers.multiply([abs_mean, scales]) 106 107 # Soft thresholding 108 sub = keras.layers.subtract([residual_abs, thres]) 109 zeros = keras.layers.subtract([sub, sub]) 110 n_sub = keras.layers.maximum([sub, zeros]) 111 residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub]) 112 113 # Downsampling (it is important to use the pooL-size of (1, 1)) 114 if downsample_strides > 1: 115 identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity) 116 117 # Zero_padding to match channels (it is important to use zero padding rather than 1by1 convolution) 118 if in_channels != out_channels: 119 identity = Lambda(pad_backend, arguments={'in_channels':in_channels,'out_channels':out_channels})(identity) 120 121 residual = keras.layers.add([residual, identity]) 122 123 return residual 124 125 126 # define and train a model 127 inputs = Input(shape=input_shape) 128 net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs) 129 net = residual_shrinkage_block(net, 1, 8, downsample=True) 130 net = BatchNormalization()(net) 131 net = Activation('relu')(net) 132 net = GlobalAveragePooling2D()(net) 133 outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net) 134 model = Model(inputs=inputs, outputs=outputs) 135 model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy']) 136 model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test)) 137 138 # get results 139 K.set_learning_phase(0) 140 DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0) 141 print('Train loss:', DRSN_train_score[0]) 142 print('Train accuracy:', DRSN_train_score[1]) 143 DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0) 144 print('Test loss:', DRSN_test_score[0]) 145 print('Test accuracy:', DRSN_test_score[1])
4.TFLean代码示例
1 #!/usr/bin/env python3 2 # -*- coding: utf-8 -*- 3 """ 4 Created on Mon Dec 23 21:23:09 2019 5 Implemented using TensorFlow 1.0 and TFLearn 0.3.2 6 7 M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis, 8 IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898 9 10 @author: super_9527 11 """ 12 13 from __future__ import division, print_function, absolute_import 14 15 import tflearn 16 import numpy as np 17 import tensorflow as tf 18 from tflearn.layers.conv import conv_2d 19 20 # Data loading 21 from tflearn.datasets import cifar10 22 (X, Y), (testX, testY) = cifar10.load_data() 23 24 # Add noise 25 X = X + np.random.random((50000, 32, 32, 3))*0.1 26 testX = testX + np.random.random((10000, 32, 32, 3))*0.1 27 28 # Transform labels to one-hot format 29 Y = tflearn.data_utils.to_categorical(Y,10) 30 testY = tflearn.data_utils.to_categorical(testY,10) 31 32 def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False, 33 downsample_strides=2, activation='relu', batch_norm=True, 34 bias=True, weights_init='variance_scaling', 35 bias_init='zeros', regularizer='L2', weight_decay=0.0001, 36 trainable=True, restore=True, reuse=False, scope=None, 37 name="ResidualBlock"): 38 39 # residual shrinkage blocks with channel-wise thresholds 40 41 residual = incoming 42 in_channels = incoming.get_shape().as_list()[-1] 43 44 # Variable Scope fix for older TF 45 try: 46 vscope = tf.variable_scope(scope, default_name=name, values=[incoming], 47 reuse=reuse) 48 except Exception: 49 vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse) 50 51 with vscope as scope: 52 name = scope.name #TODO 53 54 for i in range(nb_blocks): 55 56 identity = residual 57 58 if not downsample: 59 downsample_strides = 1 60 61 if batch_norm: 62 residual = tflearn.batch_normalization(residual) 63 residual = tflearn.activation(residual, activation) 64 residual = conv_2d(residual, out_channels, 3, 65 downsample_strides, 'same', 'linear', 66 bias, weights_init, bias_init, 67 regularizer, weight_decay, trainable, 68 restore) 69 70 if batch_norm: 71 residual = tflearn.batch_normalization(residual) 72 residual = tflearn.activation(residual, activation) 73 residual = conv_2d(residual, out_channels, 3, 1, 'same', 74 'linear', bias, weights_init, 75 bias_init, regularizer, weight_decay, 76 trainable, restore) 77 78 # get thresholds and apply thresholding 79 abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True) 80 scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling') 81 scales = tflearn.batch_normalization(scales) 82 scales = tflearn.activation(scales, 'relu') 83 scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling') 84 scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1) 85 thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales)) 86 # soft thresholding 87 residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0)) 88 89 90 # Downsampling 91 if downsample_strides > 1: 92 identity = tflearn.avg_pool_2d(identity, 1, 93 downsample_strides) 94 95 # Projection to new dimension 96 if in_channels != out_channels: 97 if (out_channels - in_channels) % 2 == 0: 98 ch = (out_channels - in_channels)//2 99 identity = tf.pad(identity, 100 [[0, 0], [0, 0], [0, 0], [ch, ch]]) 101 else: 102 ch = (out_channels - in_channels)//2 103 identity = tf.pad(identity, 104 [[0, 0], [0, 0], [0, 0], [ch, ch+1]]) 105 in_channels = out_channels 106 107 residual = residual + identity 108 109 return residual 110 111 112 # Real-time data preprocessing 113 img_prep = tflearn.ImagePreprocessing() 114 img_prep.add_featurewise_zero_center(per_channel=True) 115 116 # Real-time data augmentation 117 img_aug = tflearn.ImageAugmentation() 118 img_aug.add_random_flip_leftright() 119 img_aug.add_random_crop([32, 32], padding=4) 120 121 # Build a Deep Residual Shrinkage Network with 3 blocks 122 net = tflearn.input_data(shape=[None, 32, 32, 3], 123 data_preprocessing=img_prep, 124 data_augmentation=img_aug) 125 net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001) 126 net = residual_shrinkage_block(net, 1, 16) 127 net = residual_shrinkage_block(net, 1, 32, downsample=True) 128 net = residual_shrinkage_block(net, 1, 32, downsample=True) 129 net = tflearn.batch_normalization(net) 130 net = tflearn.activation(net, 'relu') 131 net = tflearn.global_avg_pool(net) 132 # Regression 133 net = tflearn.fully_connected(net, 10, activation='softmax') 134 mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True) 135 net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy') 136 # Training 137 model = tflearn.DNN(net, checkpoint_path='model_cifar10', 138 max_checkpoints=10, tensorboard_verbose=0, 139 clip_gradients=0.) 140 141 model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500, 142 show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10') 143 144 training_acc = model.evaluate(X, Y)[0] 145 validation_acc = model.evaluate(testX, testY)[0]
文献来源
M. Zhao, S, Zhong, X. Fu, et al. Deep residual shrinkage networks for fault diagnosis. IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
https://ieeexplore.ieee.org/document/8850096/
源代码