您的位置:首页 > 其它

深度之眼Pytorch框架训练营第四期——模型创建与nn.Module

2020-06-04 07:55 330 查看

文章目录

模型创建与
nn.Module

1、模型创建步骤

模型的创建示意图如下:

从上图中可以看出,模型的创建与权值初始化共同构成了模型,模型的创建只要包括了:

  • 构建网络层:卷积层,池化层,激活函数等;
  • 拼接网络层:网络层有构建网络层后,需要进行网络层的拼接,拼接成LeNetLeNetLeNet,AlexNetAlexNetAlexNet和ResNetResNetResNet等
    创建好模型后,需要对模型进行权值初始化,
    PyTorch
    中的初始化方法主要有:
    Xavier
    Kaiming
    ,均匀分布,正态分布等方法。

2、
nn.Module

  • 第一部分中讲到的模型的创建权值初始化
    PyTorch
    中均需要通过
    nn.Module
    来完成,
    nn.Module
    是整个模块的根基
  • nn.Module
    torch.nn
    中的模块,
    torch.nn
    中一共有四个模块,如下图所示:

  • nn.Module
    中有八个重要的属性用于管理整个模型:
  • parameters
    : 存储管理
    nn.Parameter
  • modules
    :存储管理
    nn.Module
  • buffers
    :存储管理缓冲属性,如BN层中的
    running_mean
  • ***_hooks
    :共有5个,存储管理钩子函数
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()

3、以
LeNet
模型为例探究
nn.Module

  • 如图所示,
    LeNet
    由很多网络层构成,包括两个卷积层,两个池化层和三个全连接层
    LeNet: Conv1 -> pool1 -> Conv2 -> pool2 -> fc1 -> fc2 -> fc3

  • 将上图转为一个计算图的形式,如下图所示,计算图有两个主要的概念:一个是节点一个是边,节点就是张量数据,边就是运算,在图中就是箭头
  • 构建模型有两要素,第一是构建子模块,比如
    LeNet
    是由很多网络层构成的,所以首先得构建子模块中的网络层;构建好网络层后,第二是拼接子模块,按照一定拓扑结构拼接子模块就可以得到模型,构建子模块需要用到
    __init__()
    函数,而拼接子模块需要用到
    forward()
    函数,下面针对这两个函数进行讲解
(1)初始化部分:
__init__()
class LeNet(nn.Module):
def __init__(self, classes):
super(LeNet, self).__init__()    # 继承父类nn.Module的初始化
self.conv1 = nn.Conv2d(3, 6, 5)    # 卷积层,卷积核为5*5,输入通道为3,输出通道为6
self.conv2 = nn.Conv2d(6, 16, 5)    # 卷积层
self.fc1 = nn.Linear(16*5*5, 120)c    # 全连接层
self.fc2 = nn.Linear(120, 84)	# 全连接层
self.fc3 = nn.Linear(84, classes)	# 全连接层

#####(2)拼接部分:

forward()

def forward(self, x):
out = F.relu(self.conv1(x))  # import torch.nn.functional as F
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
(3)
nn.Module
的属性构建

nn.Module
的属性构建会在
module
类中进行属性赋值的时候会被
setattr()
函数拦截,在这个函数当中会判断即将要赋值的数据类型是否是
nn.parameters
类,如果是的话就会存储到
parameters
字典中;如果是
module
类就会存储到
modul
字典中

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: