您的位置:首页 > 理论基础 > 计算机网络

残差学习+注意力机制+软阈值函数:残差收缩网络(附代码)

2020-06-05 10:17 435 查看

顾名思义,深度残差收缩网络是在“残差学习ResNet”基础上的一种改进网络,是由“残差学习”和“收缩”两部分所组成的。其中,ResNet在2016年斩获了ImageNet图像识别竞赛的冠军,目前已经成为了深度学习领域的基础网络;“收缩”指的是“软阈值化”,是许多信号降噪算法的关键步骤。在深度残差收缩网络中,软阈值化所需要的阈值,实质上是借助注意力机制设置的。

在本文中,我们首先对残差网络、软阈值化和注意力机制的基础知识进行了简要的回顾,然后对深度残差收缩网络的动机、算法和应用展开解读。

1.基础回顾

1.1 残差网络

从本质上讲,残差网络(又称深度残差网络、深度残差学习)是一种卷积神经网络。相较于普通的卷积神经网络,残差网络采用了跨层恒等连接,以减轻卷积神经网络的训练难度。残差网络的一种基本模块如图所示。

图1 残差网络的一种基本模块

1.2 软阈值化

软阈值化是许多信号降噪方法的核心步骤。它的用处是将绝对值低于某个阈值的特征置为零,将其他的特征也朝着零进行调整,也就是“收缩”。在这里,阈值是一个需要预先设置的参数,其取值大小对于降噪的结果有着直接的影响。软阈值化的输入与输出之间的关系如下图所示。

图2 软阈值化

从图2可以看出,软阈值化是一种非线性变换,有着与ReLU激活函数非常相似的性质:梯度要么是0,要么是1。因此,软阈值化也能够作为神经网络的激活函数。事实上,一些神经网络已经将软阈值化作为激活函数进行了使用。

1.3 注意力机制

注意力机制就是将注意力集中于局部关键信息的机制,可以分为两步:第一,通过扫描全局信息,发现局部有用信息;第二,增强有用信息并抑制冗余信息。

Squeeze-and-Excitation Network是一种非常经典的注意力机制下的深度学习方法。它可以通过一个小型的子网络,自动学习得到一组权重,对特征图的各个通道进行加权。其含义在于,某些特征通道是比较重要的,而另一些特征通道是信息冗余的;那么,我们就可以通过这种方式增强有用特征通道、削弱冗余特征通道。Squeeze-and-Excitation Network的一种基本模块如下图所示。

图3 Squeeze-and-Excitation Network的一种基本模块

值得指出的是,通过这种方式,每个样本都可以有自己独特的一组权重,可以根据样本自身的特点,进行独特的特征通道加权调整。例如,样本A的第一特征通道是重要的,第二特征通道是不重要的;而样本B的第一特征通道是不重要的,第二特征通道是重要的;通过这种方式,样本A可以有自己的一组权重,以加强第一特征通道,削弱第二特征通道;同样地,样本B可以有自己的一组权重,以削弱第一特征通道,加强第二特征通道。

2.深度残差收缩网络理论

2.1 动机

首先,现实世界中的数据,或多或少都含有一些冗余信息。那么我们就可以尝试将软阈值化嵌入残差网络中,以进行冗余信息的消除。

其次,各个样本中冗余信息含量经常是不同的。那么我们就可以借助注意力机制,根据各个样本的情况,自适应地给各个样本设置不同的阈值。

2.2 算法

与残差网络和Squeeze-and-Excitation Network相似,深度残差收缩网络也是由许多基本模块堆叠而成的。每个基本模块都有一个子网络,用于自动学习得到一组阈值,用于特征图的软阈值化。值得指出的是,通过这种方式,每个样本都有着自己独特的一组阈值。深度残差收缩网络的一种基本模块如下图所示。

图4 深度残差收缩网络的一种基本模块

深度残差收缩网络的整体结构如下图所示,是由输入层、许多基本模块以及最后的全连接输出层等组成的。

图5 深度残差收缩网络的整体结构

2.3 应用

在原始论文中,深度残差收缩网络是应用于基于振动信号的旋转机械故障诊断。但是从原理上来讲,深度残差收缩网络面向的是数据集含有冗余信息的情况,而冗余信息是无处不在的。例如,在图像识别的时候,图像中总会包含一些与标签无关的区域;在语音识别的时候,音频中经常会含有各种形式的噪声。因此,深度残差收缩网络,或者说这种“深度学习”+“软阈值化”+“注意力机制”的思路,有着较为广泛的研究前景。

3.Keras和TFLearn程序简介

本程序以图像分类为例,构建了小型的深度残差收缩网络,超参数也未进行优化。为追求高准确率的话,可以适当增加深度,增加训练迭代次数,以及适当调整超参数。下面是Keras程序:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 28 23:24:05 2019
Implemented using TensorFlow 1.0.1 and Keras 2.2.1

M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""

from __future__ import print_function
import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.layers.core import Lambda
K.set_learning_phase(1)

# Input image dimensions
img_rows, img_cols = 28, 28

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)

# Noised data
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1])
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1])
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

def abs_backend(inputs):
return K.abs(inputs)

def expand_dim_backend(inputs):
return K.expand_dims(K.expand_dims(inputs,1),1)

def sign_backend(inputs):
return K.sign(inputs)

def pad_backend(inputs, in_channels, out_channels):
pad_dim = (out_channels - in_channels)//2
inputs = K.expand_dims(inputs,-1)
inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last')
return K.squeeze(inputs, -1)

# Residual Shrinakge Block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
downsample_strides=2):

residual = incoming
in_channels = incoming.get_shape().as_list()[-1]

for i in range(nb_blocks):

identity = residual

if not downsample:
downsample_strides = 1

residual = BatchNormalization()(residual)
residual = Activation('relu')(residual)
residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides),
padding='same', kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(residual)

residual = BatchNormalization()(residual)
residual = Activation('relu')(residual)
residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(residual)

# Calculate global means
residual_abs = Lambda(abs_backend)(residual)
abs_mean = GlobalAveragePooling2D()(residual_abs)

# Calculate scaling coefficients
scales = Dense(out_channels, activation=None, kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(abs_mean)
scales = BatchNormalization()(scales)
scales = Activation('relu')(scales)
scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
scales = Lambda(expand_dim_backend)(scales)

# Calculate thresholds
thres = keras.layers.multiply([abs_mean, scales])

# Soft thresholding
sub = keras.layers.subtract([residual_abs, thres])
zeros = keras.layers.subtract([sub, sub])
n_sub = keras.layers.maximum([sub, zeros])
residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])

# Downsampling (it is important to use the pooL-size of (1, 1))
if downsample_strides > 1:
identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity)

# Zero_padding to match channels (it is important to use zero padding rather than 1by1 convolution)
if in_channels != out_channels:
identity = Lambda(pad_backend, arguments={'in_channels':in_channels,'out_channels':out_channels})(identity)

residual = keras.layers.add([residual, identity])

return residual

# define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))

# get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])

下面是TFLearn程序:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2

M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

@author: super_9527
"""

from __future__ import division, print_function, absolute_import

import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d

# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()

# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1

# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y,10)
testY = tflearn.data_utils.to_categorical(testY,10)

def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
downsample_strides=2, activation='relu', batch_norm=True,
bias=True, weights_init='variance_scaling',
bias_init='zeros', regularizer='L2', weight_decay=0.0001,
trainable=True, restore=True, reuse=False, scope=None,
name="ResidualBlock"):

# residual shrinkage blocks with channel-wise thresholds

residual = incoming
in_channels = incoming.get_shape().as_list()[-1]

# Variable Scope fix for older TF
try:
vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
reuse=reuse)
except Exception:
vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)

with vscope as scope:
name = scope.name #TODO

for i in range(nb_blocks):

identity = residual

if not downsample:
downsample_strides = 1

if batch_norm:
residual = tflearn.batch_normalization(residual)
residual = tflearn.activation(residual, activation)
residual = conv_2d(residual, out_channels, 3,
downsample_strides, 'same', 'linear',
bias, weights_init, bias_init,
regularizer, weight_decay, trainable,
restore)

if batch_norm:
residual = tflearn.batch_normalization(residual)
residual = tflearn.activation(residual, activation)
residual = conv_2d(residual, out_channels, 3, 1, 'same',
'linear', bias, weights_init,
bias_init, regularizer, weight_decay,
trainable, restore)

# get thresholds and apply thresholding
abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True)
scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
scales = tflearn.batch_normalization(scales)
scales = tflearn.activation(scales, 'relu')
scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1)
thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales))
# soft thresholding
residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0))

# Downsampling
if downsample_strides > 1:
identity = tflearn.avg_pool_2d(identity, 1,
downsample_strides)

# Projection to new dimension
if in_channels != out_channels:
if (out_channels - in_channels) % 2 == 0:
ch = (out_channels - in_channels)//2
identity = tf.pad(identity,
[[0, 0], [0, 0], [0, 0], [ch, ch]])
else:
ch = (out_channels - in_channels)//2
identity = tf.pad(identity,
[[0, 0], [0, 0], [0, 0], [ch, ch+1]])
in_channels = out_channels

residual = residual + identity

return residual

# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)

# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
data_preprocessing=img_prep,
data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
max_checkpoints=10, tensorboard_verbose=0,
clip_gradients=0.)

model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')

training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]

参考文献

M. Zhao, S, Zhong, X. Fu, et al. Deep residual shrinkage networks for fault diagnosis. IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

https://ieeexplore.ieee.org/document/8850096

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: