您的位置:首页 > 理论基础 > 计算机网络

深度残差收缩网络:注意力机制和软阈值化的集成

2020-03-01 11:35 323 查看

深度残差收缩网络(deep residual shrinkage networks)是深度残差学习(deep residual learning, ResNet)的一种改进形式。具体而言, 深度残差收缩网络是在ResNet的内部集成了注意力机制和软阈值化,加强深度学习方法从含噪声信号中学习判别性特征的能力,提高分类准确率。以下根据自己的理解,对深度残差收缩网络进行解释。

1. 动机

首先,在许多分类任务中,样本中往往含有各种噪声,比如高斯噪声、粉色噪声等。更宽泛地讲,样本中可能包含着与当前分类任务无关的信息。

比如说,在很多情况下,所需要分类的图片,不仅包含与标签所对应的目标物体,而且包含着与标签无关的物体。这些与标签无关的物体,就可以理解为噪声。这些物体所对应的特征,就需要被滤除掉,以免对当前分类任务造成干扰。或者说,如果我们在马路边聊天,我们聊天的声音可能会混杂了一些车辆的鸣笛声、车轮声。如果对这种信号进行语音识别,识别的准确率就会受到鸣笛声、车轮声的干扰。因此,这些鸣笛声、车轮声所对应的特征,就应该在深度学习算法内部被滤除掉,以免对语音识别任务造成影响。

其次,在同一批样本中,各个样本所包含的噪声往往是不同的。

比如说,我们要训练一个猫狗分类器。对于标签为“狗”的五张训练图像,第一张图片可能包含了狗和老鼠,第二张图片可能包含了狗和鹅,第三张图片可能包含了狗和鸡,第四张图片可能包含了狗和兔子,第五张可能包含了狗和鸭子。我们在执行分类任务的时候,就可能会受到老鼠、鹅、鸡、兔子和鸭子这些无关物体的影响,导致分类效果不好。如果我们能够注意到这些无关的老鼠、鹅、鸡、兔子和鸭子,将它们所对应的特征置为零,就有可能提高猫狗分类器的分类效果。

2. 软阈值化

软阈值化,就是将绝对值小于某个阈值的特征置为零,将绝对值大于这个阈值的特征朝着零的方向进行收缩。它的公式为

软阈值化的输出对于输入的导数为

我们可以发现,它的导数要么为1,要么为0。这个性质是和ReLU激活函数是一样的。所以,软阈值化也可以减小梯度消失和梯度爆炸的风险。

在软阈值化中,阈值的取值必须满足一定的条件: 第一,阈值必须是正数;第二,阈值不能太大,否则输出会全部为零。

同时,我们希望,阈值还能够满足第三个条件:每个样本应该有不同的阈值。

这是因为,许多样本所含的噪声量经常是不同的。例如,样本A所含噪声较少,样本B所含噪声较多。那么,在降噪算法里面,样本A的阈值就应该大一点,样本B的阈值就应该小一些。在深度学习算法里,由于这些特征没有明确的物理意义,阈值的大小也无法得到解释。但是道理是相通的,即每个样本应该有不同的阈值。

3. 注意力机制

注意力机制在计算机视觉领域有着非常直观的解释。 我们人类能够通过快速扫描视觉区域,发现目标物体,进而将大部分注意力集中在目标物体上,以获得更多细节信息,同时抑制无关物体所对应的信息。

Squeeze-and-Excitation Network(SENet)是一种广为人知的注意力机制下的深度神经网络。 在不同的样本中,不同的特征通道,在分类任务中的贡献大小,经常是不一样的。SENet通过一个小型子网络,学习得到一组权重,将这组权重与不同通道的特征相乘,以调整不同通道的特征。这个过程,就可以理解为,施加不同的注意力在各个特征通道上(见下图)。

需要指出的是,每个样本,都有自己独特的一组权重。任意两个样本,它们的这些权重,都是不一样的。在SENet中,具体的路径包括,全局均值池化→全连接层→ReLU函数→全连接层→Sigmoid函数。

深度残差收缩网络就采用了相似的子网络来自动地设置软阈值化所需要的阈值。

在红色框里的子网络,学习得到了一组阈值,应用在特征图的各个通道上。

在这个子网络中,首先对输入特征图的所有元素,求它们的绝对值。然后经过全局均值池化和求平均,就获得了一个特征。在这里,为便于描述,将这个特征记为A。在另一条路径中,全局均值池化之后的特征图,被输入到一个全连接网络。这个全连接网络以Sigmoid激活函数作为最后一步,将输出调整到0和1之间,获得一个系数,记为α。最终的阈值就是α×A。最后的话,阈值就是,一个0和1之间的数字×特征图的绝对值的平均值。 通过这种方式,不仅保证了阈值为正数,而且不会太大。

值得指出的是,通过这种方式,不同的样本就有了不同的阈值。在某种程度上,可以理解成一种特殊的注意力机制:注意到与当前任务无关的特征,将它们置为零;或者说,注意到与当前任务有关的特征,将它们保留下来。

4. 通用性

深度残差收缩网络其实可以作为一种通用的分类方法。 虽然这篇文章原本是面向基于振动信号的故障诊断,其实深度残差收缩网络可以应用于其他的分类任务,比如图像分类、语音识别等等。具体而言,在图像分类的任务中,假如图像中存在着很多的其他物体,那么这些物体就可以被认为是“噪声”;深度残差收缩网络或许可以借助注意力机制和软阈值化,将这些“噪声”所对应的特征删除掉,提高图像分类的效果。对于语音识别,在环境较为嘈杂的情况下,比如在马路边聊天的场景,深度残差收缩网络或许可以提高语音识别准确率,或者提供了一种提高语音识别准确率的思路。

参考网址

深度残差收缩网络:(四)注意力机制下的阈值设置

https://www.cnblogs.com/yc-9527/p/11604082.html

代码网址

将深度残差收缩网络用于图像分类。代码中只构建了一个很小的网络,只有3个残差模块。如果为了追求更高准确率的话,可以适当增加深度,增加优化迭代次数,以及适当优化超参数。
https://github.com/zhao62/Deep-Residual-Shrinkage-Networks

原文网址

M. Zhao, S. Zhong, X. Fu, et al., Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, DOI: 10.1109/TII.2019.2943898

https://ieeexplore.ieee.org/document/8850096

  • 点赞
  • 收藏
  • 分享
  • 文章举报
residual_fan 发布了1 篇原创文章 · 获赞 0 · 访问量 209 私信 关注
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: