深度之眼Pytorch框架训练营第四期——模型创建与nn.Module
2020-06-04 07:55
330 查看
文章目录
- 模型创建与` nn.Module`
- 1、模型创建步骤
- 2、`nn.Module`
- 3、以`LeNet`模型为例探究`nn.Module`
- (1)初始化部分:`__init__()`
- (3)`nn.Module`的属性构建
模型创建与nn.Module
1、模型创建步骤
模型的创建示意图如下:
从上图中可以看出,模型的创建与权值初始化共同构成了模型,模型的创建只要包括了:
- 构建网络层:卷积层,池化层,激活函数等;
- 拼接网络层:网络层有构建网络层后,需要进行网络层的拼接,拼接成LeNetLeNetLeNet,AlexNetAlexNetAlexNet和ResNetResNetResNet等
创建好模型后,需要对模型进行权值初始化,PyTorch
中的初始化方法主要有:Xavier
,Kaiming
,均匀分布,正态分布等方法。
2、nn.Module
- 第一部分中讲到的模型的创建与权值初始化在
PyTorch
中均需要通过nn.Module
来完成,nn.Module
是整个模块的根基 nn.Module
是torch.nn
中的模块,torch.nn
中一共有四个模块,如下图所示:
nn.Module
中有八个重要的属性用于管理整个模型:parameters
: 存储管理nn.Parameter
类modules
:存储管理nn.Module
类buffers
:存储管理缓冲属性,如BN层中的running_mean
***_hooks
:共有5个,存储管理钩子函数
self._parameters = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = OrderedDict() self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() self._state_dict_hooks = OrderedDict() self._load_state_dict_pre_hooks = OrderedDict() self._modules = OrderedDict()
3、以LeNet
模型为例探究nn.Module
- 如图所示,
LeNet
由很多网络层构成,包括两个卷积层,两个池化层和三个全连接层(LeNet: Conv1 -> pool1 -> Conv2 -> pool2 -> fc1 -> fc2 -> fc3
)
- 将上图转为一个计算图的形式,如下图所示,计算图有两个主要的概念:一个是节点一个是边,节点就是张量数据,边就是运算,在图中就是箭头
- 构建模型有两要素,第一是构建子模块,比如
LeNet
是由很多网络层构成的,所以首先得构建子模块中的网络层;构建好网络层后,第二是拼接子模块,按照一定拓扑结构拼接子模块就可以得到模型,构建子模块需要用到__init__()
函数,而拼接子模块需要用到forward()
函数,下面针对这两个函数进行讲解
(1)初始化部分:__init__()
class LeNet(nn.Module): def __init__(self, classes): super(LeNet, self).__init__() # 继承父类nn.Module的初始化 self.conv1 = nn.Conv2d(3, 6, 5) # 卷积层,卷积核为5*5,输入通道为3,输出通道为6 self.conv2 = nn.Conv2d(6, 16, 5) # 卷积层 self.fc1 = nn.Linear(16*5*5, 120)c # 全连接层 self.fc2 = nn.Linear(120, 84) # 全连接层 self.fc3 = nn.Linear(84, classes) # 全连接层
#####(2)拼接部分:
forward()
def forward(self, x): out = F.relu(self.conv1(x)) # import torch.nn.functional as F out = F.max_pool2d(out, 2) out = F.relu(self.conv2(out)) out = F.max_pool2d(out, 2) out = out.view(out.size(0), -1) out = F.relu(self.fc1(out)) out = F.relu(self.fc2(out)) out = self.fc3(out) return out
(3)nn.Module
的属性构建
nn.Module的属性构建会在
module类中进行属性赋值的时候会被
setattr()函数拦截,在这个函数当中会判断即将要赋值的数据类型是否是
nn.parameters类,如果是的话就会存储到
parameters字典中;如果是
module类就会存储到
modul字典中
相关文章推荐
- 深度之眼Pytorch框架训练营第四期——模型容器与AlexNet构建
- 深度之眼Pytorch框架训练营第四期——正则化之 Dropout
- 深度之眼Pytorch框架训练营第四期——Hook函数与CAM算法
- 深度之眼Pytorch框架训练营第四期——PyTorch中的可视化工具
- 深度之眼Pytorch框架训练营第四期——PyTorch中学习率调整策略
- 深度之眼Pytorch框架训练营第四期——损失函数
- 深度之眼Pytorch框架训练营第四期——权值初始化
- 深度之眼Pytorch框架训练营第四期——池化、线性、激活函数层
- 深度之眼Pytorch框架训练营第四期——神经网络中的卷积层
- 深度之眼Pytorch框架训练营第四期——数据读取机制中的Dataloader与Dataset
- PyTorch深度学习计算机视觉框架
- 自学pytorch深度学习遇到的坑之模型加载出错
- 深度学习框架之Pytorch学习(一)
- 【异周话题 第 18 期】TensorFlow与PyTorch,深度学习框架你选哪一个?
- 深度学习:pytorch用预训练pre-train模型微调参数
- 用 PyTorch 框架做深度学习会如此简单
- Pytorch基础——从nn.module转写成.py脚本(一)
- windous10下+Anaconda+深度学习框架(TensorFlow cpu/gpu 、Keras、Pytorch)+Cuda+Cudnn+pycharm安装教程及避坑手册
- 深度学习入门笔记(十五):深度学习框架(TensorFlow和Pytorch之争)
- 【PyTorch框架学习】-创建tensor