Noisy Data? Try the Deep Residual Shrinkage Network (with Code)

The deep residual shrinkage network is a deep learning method designed to learn features from heavily noise-corrupted data. It combines two components: the deep residual network and "shrinkage".

On the one hand, the deep residual network has become a foundational architecture in deep learning.

On the other hand, "shrinkage" refers to soft thresholding, a key step in many signal-denoising algorithms.

More importantly, in the deep residual shrinkage network, the thresholds required for soft thresholding are set automatically by an attention mechanism, which avoids the trouble of tuning them by hand.

This article first briefly reviews the basics of residual networks, soft thresholding, and attention mechanisms, and then explains the motivation, algorithm, and applications of the deep residual shrinkage network.


1. Background

1.1 Residual Networks

The residual network (also known as the deep residual network or deep residual learning; ResNet) is a type of convolutional neural network. Compared with a plain convolutional neural network, a residual network adds cross-layer identity connections, which reduce the difficulty of training. A common basic module of a residual network is shown in Figure 1.
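
For orientation, a minimal identity residual block in Keras might look like the sketch below. This is an illustration of the idea only, not the exact module of Figure 1, and it assumes the input already has `channels` channels:

from keras.layers import Conv2D, BatchNormalization, Activation, add

def basic_residual_block(x, channels):
    # Two 3x3 convolutions learn a residual mapping F(x); the cross-layer
    # identity connection adds the input back, so the block outputs F(x) + x.
    shortcut = x
    y = Conv2D(channels, 3, padding='same')(x)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv2D(channels, 3, padding='same')(y)
    y = BatchNormalization()(y)
    y = add([shortcut, y])  # identity shortcut eases gradient flow
    return Activation('relu')(y)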

1.2 Soft Thresholding

Soft thresholding is the core step of many signal-denoising methods. It sets features whose absolute values fall below a threshold to zero and shrinks the remaining features toward zero as well; this is the "shrinkage". Formally, it maps x to sign(x) * max(|x| - τ, 0), where the threshold τ is a parameter that must be set in advance and whose value directly affects the denoising result. The input-output relationship of soft thresholding is shown in the figure.

As the figure shows, soft thresholding is a nonlinear transformation with a property similar to the ReLU activation function: its derivative is either 0 or 1. Soft thresholding can therefore also serve as an activation function in a neural network, and in fact some neural networks have already used it that way.
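
To make this concrete, here is a minimal NumPy sketch of soft thresholding, where tau is the manually chosen threshold discussed above:

import numpy as np

def soft_threshold(x, tau):
    # Zero out entries with |x| < tau; shrink all other entries toward zero by tau
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

x = np.array([-2.0, -0.3, 0.1, 0.8, 1.5])
print(soft_threshold(x, 0.5))  # -> [-1.5 -0.   0.   0.3  1. ]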

1.3 Attention Mechanisms

An attention mechanism concentrates attention on locally important information. It can be divided into two steps: first, scan the whole input to locate the useful local information; second, enhance the useful information and suppress the redundant information.

The Squeeze-and-Excitation Network is a classic deep learning method built on the attention mechanism. Through a small sub-network, it automatically learns a set of weights that reweight the individual channels of a feature map. The idea is that some feature channels are important while others are redundant, so this mechanism can strengthen the useful channels and weaken the redundant ones. A basic Squeeze-and-Excitation module is shown in the figure below.

Notably, in this way each sample gets its own set of weights, so the channel reweighting adapts to the characteristics of each individual sample. For example, suppose the first channel is important for sample A but not its second, while for sample B it is the reverse. Then sample A gets weights that strengthen its first channel and weaken its second, and sample B gets weights that weaken its first channel and strengthen its second.
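
A minimal Keras sketch of such a Squeeze-and-Excitation block follows; this is a simplified illustration, and the reduction ratio is an assumed hyperparameter:

from keras.layers import GlobalAveragePooling2D, Dense, Reshape, multiply

def se_block(x, channels, reduction=4):
    # Squeeze: global average pooling gives one statistic per channel
    w = GlobalAveragePooling2D()(x)               # shape (batch, channels)
    # Excitation: a small two-layer sub-network learns per-channel weights
    w = Dense(channels // reduction, activation='relu')(w)
    w = Dense(channels, activation='sigmoid')(w)  # one weight in (0, 1) per channel
    w = Reshape((1, 1, channels))(w)
    # Reweight: strengthen useful channels, weaken redundant ones
    return multiply([x, w])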


2. Theory of the Deep Residual Shrinkage Network

2.1 Motivation

First, real-world data almost always contains some redundant information. We can therefore try to embed soft thresholding into a residual network to remove that redundancy.

Second, the amount of redundant information often differs from sample to sample. We can therefore use an attention mechanism to adaptively set a different threshold for each sample, according to that sample's own situation.

2.2 Algorithm

Like the residual network and the Squeeze-and-Excitation Network, the deep residual shrinkage network is built by stacking many basic modules. Each module contains a sub-network that automatically learns a set of thresholds used to soft-threshold the feature map. Notably, each sample thereby gets its own unique set of thresholds. A basic module of the deep residual shrinkage network is shown in the figure below.
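
In code terms, the threshold sub-network can be sketched as follows; this is a simplified excerpt consistent with, but condensed from, the full Keras implementation in Section 3:

from keras import backend as K
from keras.layers import (GlobalAveragePooling2D, Dense, BatchNormalization,
                          Activation, Lambda, multiply)

def channel_thresholds(feature_map, channels):
    # Average absolute feature value of each channel, per sample
    abs_mean = GlobalAveragePooling2D()(Lambda(K.abs)(feature_map))
    # A small sub-network maps it to a scaling factor in (0, 1)
    gate = Dense(channels)(abs_mean)
    gate = Activation('relu')(BatchNormalization()(gate))
    gate = Dense(channels, activation='sigmoid')(gate)
    # Threshold = (average |feature|) x (learned scale), so each threshold is
    # always positive and never exceeds the scale of the features themselves
    return multiply([abs_mean, gate])  # shape (batch, channels)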

The overall architecture of the deep residual shrinkage network, shown in the figure below, consists of an input layer, a stack of basic modules, and a final fully connected output layer.

2.3 Applications

In the original paper, the deep residual shrinkage network was applied to fault diagnosis of mechanical equipment from vibration signals. In principle, however, it targets any dataset that contains redundant information, and redundant information is everywhere. For example, in image recognition, images always contain regions unrelated to the label; in speech recognition, audio often contains noise of various kinds. Hence the deep residual shrinkage network, or more broadly the idea of integrating an attention mechanism with soft thresholding inside a deep learning algorithm, has fairly broad research value and application prospects.


3. Keras Code Example

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 28 23:24:05 2019
Implemented using TensorFlow 1.0.1 and Keras 2.2.1

M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""

from __future__ import print_function
import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.layers.core import Lambda
K.set_learning_phase(1)

# Input image dimensions
img_rows, img_cols = 28, 28

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Add uniform random noise (shapes follow the data arrays, so this also
# works with the 'channels_first' layout)
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random(x_train.shape)
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random(x_test.shape)
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)


def abs_backend(inputs):
    return K.abs(inputs)

def expand_dim_backend(inputs):
    return K.expand_dims(K.expand_dims(inputs, 1), 1)

def sign_backend(inputs):
    return K.sign(inputs)

def pad_backend(inputs, in_channels, out_channels):
    pad_dim = (out_channels - in_channels)//2
    inputs = K.expand_dims(inputs, -1)
    inputs = K.spatial_3d_padding(inputs, ((0,0), (0,0), (pad_dim,pad_dim)), 'channels_last')
    return K.squeeze(inputs, -1)

# Residual shrinkage block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2):

    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    for i in range(nb_blocks):

        identity = residual

        if not downsample:
            downsample_strides = 1

        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides),
                          padding='same', kernel_initializer='he_normal',
                          kernel_regularizer=l2(1e-4))(residual)

        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal',
                          kernel_regularizer=l2(1e-4))(residual)

        # Calculate the global mean of the absolute feature values
        residual_abs = Lambda(abs_backend)(residual)
        abs_mean = GlobalAveragePooling2D()(residual_abs)

        # Calculate scaling coefficients in (0, 1)
        scales = Dense(out_channels, activation=None, kernel_initializer='he_normal',
                       kernel_regularizer=l2(1e-4))(abs_mean)
        scales = BatchNormalization()(scales)
        scales = Activation('relu')(scales)
        scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
        scales = Lambda(expand_dim_backend)(scales)

        # Calculate thresholds
        thres = keras.layers.multiply([abs_mean, scales])

        # Soft thresholding: sign(x) * max(|x| - thres, 0)
        sub = keras.layers.subtract([residual_abs, thres])
        zeros = keras.layers.subtract([sub, sub])   # a zero tensor of matching shape
        n_sub = keras.layers.maximum([sub, zeros])
        residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])

        # Downsampling (using a pool size of (1, 1) is important here)
        if downsample_strides > 1:
            identity = AveragePooling2D(pool_size=(1, 1),
                                        strides=(downsample_strides, downsample_strides))(identity)

        # Zero-padding to match channels (zero padding is important here,
        # rather than a 1x1 convolution)
        if in_channels != out_channels:
            identity = Lambda(pad_backend, arguments={'in_channels': in_channels,
                                                      'out_channels': out_channels})(identity)

        residual = keras.layers.add([residual, identity])

    return residual


# Define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))

# Get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])


4. TFLearn Code Example

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2

M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

@author: super_9527
"""

from __future__ import division, print_function, absolute_import

import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d

# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()

# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1

# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y, 10)
testY = tflearn.data_utils.to_categorical(testY, 10)

def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                   downsample_strides=2, activation='relu', batch_norm=True,
                   bias=True, weights_init='variance_scaling',
                   bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                   trainable=True, restore=True, reuse=False, scope=None,
                   name="ResidualBlock"):

    # Residual shrinkage blocks with channel-wise thresholds

    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    # Variable scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)

    with vscope as scope:
        name = scope.name  # TODO

        for i in range(nb_blocks):

            identity = residual

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                               downsample_strides, 'same', 'linear',
                               bias, weights_init, bias_init,
                               regularizer, weight_decay, trainable,
                               restore)

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                               'linear', bias, weights_init,
                               bias_init, regularizer, weight_decay,
                               trainable, restore)

            # Get channel-wise thresholds
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual), axis=2, keep_dims=True),
                                      axis=1, keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',
                                             regularizer='L2', weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',
                                             regularizer='L2', weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales, axis=1), axis=1)
            thres = tf.multiply(abs_mean, tflearn.activations.sigmoid(scales))
            # Soft thresholding
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual) - thres, 0))

            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)

            # Zero-padding to match the new channel dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                in_channels = out_channels

            residual = residual + identity

    return residual


# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)

# Build a deep residual shrinkage network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)

model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')

training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]


Reference

M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

https://ieeexplore.ieee.org/document/8850096/


Source code

https://github.com/zhao62/Deep-Residual-Shrinkage-Networks
