您的位置:首页 > 理论基础 > 计算机网络

残差网络resnet理解与pytorch代码实现

2021-10-08 19:34 841 查看

写在前面

​ 深度残差网络(Deep residual network, ResNet)自提出起,一次次刷新CNN模型在ImageNet中的成绩,解决了CNN模型难训练的问题。何凯明大神的工作令人佩服,模型简单有效,思想超凡脱俗。

​ 直观上,提到深度学习,我们第一反应是模型要足够“深”,才可以提升模型的准确率。但事实往往不尽如人意,先看一个ResNet论文中提到的实验,当用一个平原网络(plain network)构建很深层次的网络时,56层的网络的表现相比于20层的网络反而更差了。说明网络随着深度的加深,会更加难以训练。

图一:模型退化问题

​ 若模型随着网络深度的增加,准确率先上升,然后达到饱和,深度增加准确率下降。那么如果在模型达到饱和时,后面接上几个恒等变换层,这样可以保证误差不会增加,resnet便是这种思想来解决网络退化问题。

第一部分

模型

假设网络的输入是x, 期望输出为H(x),我们转化一下思路,把网络要学到的H(x)转化为期望输出H(x)与输出x之间的差值F(x) = H(x) - x。当残差接近为0时, 相当于网络在此层仅仅做了恒等变换,而不会使网络的效果下降。

图二:残差结构

残差为什么容易学习?

此处参考一位知乎大佬的分析(原文在文末有链接),因为网络要学习的残差项通常比较小:

其中 分别表示的是第 个残差单元的输入和输出,注意每个残差单元一般包含多层结构。 是残差函数,表示学习到的残差,而 表示恒等映射, 是ReLU激活函数。基于上式,我们求得从浅层 到深层 的学习特征为:

利用链式规则,可以求得反向过程的梯度:

式子的第一个因子 表示的损失函数到达 的梯度,小括号中的1表明短路机制可以无损地传播梯度,而另外一项残差梯度则需要经过带有weights的层,梯度不是直接传递过来的。残差梯度不会那么巧全为-1,而且就算其比较小,有1的存在也不会导致梯度消失。所以残差学习会更容易。要注意上面的推导并不是严格的证明。

深度残差网络结构如下:

第二部分

pytorch代码实现

# -*- coding:utf-8 -*-
# handwritten digits recognition
# Data: MINIST
# model: resnet
# date: 2021.10.8 14:18

import math
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.utils.data as Data
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt

train_curve = []
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# param
batch_size = 100
n_class = 10
padding_size = 15
epoches = 10

train_dataset = torchvision.datasets.MNIST('./data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST('./data/', train=False, transform=transforms.ToTensor(), download=False)
train = Data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=5)
test = Data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=5)

def gelu(x):
"Implementation of the gelu activation function by Hugging Face"
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class ResBlock(nn.Module):
# 残差块
def __init__(self, in_size, out_size1, out_size2):
super(ResBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_channels = in_size,
out_channels = out_size1,
kernel_size = 3,
stride = 2,
padding = padding_size
)
self.conv2 = nn.Conv2d(
in_channels = out_size1,
out_channels = out_size2,
kernel_size = 3,
stride = 2,
padding = padding_size
)
self.batchnorm1 = nn.BatchNorm2d(out_size1)
self.batchnorm2 = nn.BatchNorm2d(out_size2)

def conv(self, x):
# gelu效果比relu好呀哈哈
x = gelu(self.batchnorm1(self.conv1(x)))
x = gelu(self.batchnorm2(self.conv2(x)))
return x

def forward(self, x):
# 残差连接
return x + self.conv(x)

# resnet
class Resnet(nn.Module):
def __init__(self, n_class = n_class):
super(Resnet, self).__init__()
self.res1 = ResBlock(1, 8, 16)
self.res2 = ResBlock(16, 32, 16)
self.conv = nn.Conv2d(
in_channels = 16,
out_channels = n_class,
kernel_size = 3,
stride = 2,
padding = padding_size
)
self.batchnorm = nn.BatchNorm2d(n_class)
self.max_pooling = nn.AdaptiveAvgPool2d(1)

def forward(self, x):
# x: [bs, 1, h, w]
# x = x.view(-1, 1, 28, 28)
x = self.res1(x)
x = self.res2(x)
x = self.max_pooling(self.batchnorm(self.conv(x)))

return x.view(x.size(0), -1)

resnet = Resnet().to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=resnet.parameters(), lr=1e-2, momentum=0.9)

# train
total_step = len(train)
sum_loss = 0
for epoch in range(epoches):
for i, (images, targets) in enumerate(train):
optimizer.zero_grad()
images = images.to(device)
targets = targets.to(device)
preds = resnet(images)

loss = loss_fn(preds, targets)
sum_loss += loss.item()
loss.backward()
optimizer.step()
if (i+1)%100==0:
print('[{}|{}] step:{}/{} loss:{:.4f}'.format(epoch+1, epoches, i+1, total_step, loss.item()))
train_curve.append(sum_loss)
sum_loss = 0

# test
resnet.eval()
with torch.no_grad():
correct = 0
total = 0
for images, labels in test:
images = images.to(device)
labels = labels.to(device)
outputs = resnet(images)
_, maxIndexes = torch.max(outputs, dim=1)
correct += (maxIndexes==labels).sum().item()
total += labels.size(0)

print('in 1w test_data correct rate = {:.4f}'.format((correct/total)*100))

pd.DataFrame(train_curve).plot() # loss曲线

测试了1万条测试集样本结果:

代码链接:

jupyter版本:https://github.com/PouringRain/blog_code/blob/main/deeplearning/resnet.ipynb

py版本:https://github.com/PouringRain/blog_code/blob/main/deeplearning/resnet.py

喜欢的话,给萌新的github仓库一颗小星星哦……^ _^

参考资料:

https://zhuanlan.zhihu.com/p/31852747

https://zhuanlan.zhihu.com/p/80226180

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: