【pytorch源码赏析】nn.XX VS nn.functional.XX
2017-08-02 15:02
519 查看
先来看代码
定义网络的时候,有时候用nn.Conv2d,有时候用F.relu,两者有什么区别吗?
从上面代码可以看出,nn.Conv2d要先声明,像上面的8,9,10行,然后再调用,参与构建模型。而F.relu,直接使用,并没有声明的环节。
看源码。F.relu的源码,nn.Conv2d的源码
F.relu仅仅是一个函数,参数包括输入和计算所需参数,返回计算结果。它不能存储任何上下文信息。所有的function函数都从基类Function派生,实现forward和backward静态方法。而在forward和backward实现内部,调用了C的后端实现。比如下面这个MaxPool1d:
而nn.Conv2d是在F.conv2d外加的一层封装,从Module派生而来。Module是最接近用户的类,因为如果用户要实现自己的模型,也要从Module派生。Module提供了一种嵌套的模型定义方式,用户可以用这种简单的方式,定义很复杂的模型。关于Module最重要的一点是,它实现了__call__方法,在__call__方法里调用了forward()方法。所以,当执行self.Conv2d(x),它就在调用Conv2d里面的forward(),并返回计算结果。
nn.xx仅封装了一部分常用的计算模块,大部分还是在F.xx里面。如果需要自己用c实现operator(大部分情况不用),最后把接口提供给F.xx就可以。
另外,定义损失也是和nn.Conv2d一样的做法,先声明,再使用。
import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) self.conv1_2 = nn.Conv2d(64, 3, 3, padding=1) self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) def forward(self, x): x = F.relu(self.conv1_1(x)) x = self.pool1(F.relu(self.conv1_2(x))) return x
定义网络的时候,有时候用nn.Conv2d,有时候用F.relu,两者有什么区别吗?
从上面代码可以看出,nn.Conv2d要先声明,像上面的8,9,10行,然后再调用,参与构建模型。而F.relu,直接使用,并没有声明的环节。
看源码。F.relu的源码,nn.Conv2d的源码
# F.relu def relu(input, inplace=False): return _functions.thnn.Threshold.apply(input, 0, 0, inplace)
F.relu仅仅是一个函数,参数包括输入和计算所需参数,返回计算结果。它不能存储任何上下文信息。所有的function函数都从基类Function派生,实现forward和backward静态方法。而在forward和backward实现内部,调用了C的后端实现。比如下面这个MaxPool1d:
class MaxPool1d(Function): @staticmethod def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False): ... backend = type2backend[type(input)] backend.SpatialDilatedMaxPooling_updateOutput(...) ... return output, indices @staticmethod def backward(ctx, grad_output, _indices_grad=None): ... grad_input = MaxPool1dBackward.apply(...) return grad_input, None, None, None, None, None, None
而nn.Conv2d是在F.conv2d外加的一层封装,从Module派生而来。Module是最接近用户的类,因为如果用户要实现自己的模型,也要从Module派生。Module提供了一种嵌套的模型定义方式,用户可以用这种简单的方式,定义很复杂的模型。关于Module最重要的一点是,它实现了__call__方法,在__call__方法里调用了forward()方法。所以,当执行self.Conv2d(x),它就在调用Conv2d里面的forward(),并返回计算结果。
nn.xx仅封装了一部分常用的计算模块,大部分还是在F.xx里面。如果需要自己用c实现operator(大部分情况不用),最后把接口提供给F.xx就可以。
另外,定义损失也是和nn.Conv2d一样的做法,先声明,再使用。
相关文章推荐
- 【pytorch源码赏析】Dataset in pytorch
- PyTorch官方中文文档:torch.nn.functional
- 『PyTorch』第十二弹_nn.Module和nn.functional
- Cocos2dx源码赏析(4)之Action动作
- Functional Programming vs. Imperative Programming
- 关于如何使用VS高亮显示无扩展名源码文件的一个小技巧
- VS2010 + .net4.0 一个小程序锁屏软件 支持开机自启动 源码 免积分下载
- 采用Reflector的VS.net插件断点调试无源码DLL
- java1.7集合源码赏析系列:线程池原理
- 【OpenCV图像处理入门学习教程一】OpenCV2 + 3的安装教程与VS2013的开发环境配置 + JPEG压缩源码分析与取反运算修改
- dubbo2.5.6 vs 2.5.7的源码比较
- VS2017 nlog源码查看报错
- winform 入门开发,VS 程序自带的日历控件日期显示格式是xxxx年x月x日,如何改成xxxx-xx-xx 的格式 ?
- 基于新唐M0的XXTEA加密解密算法源码
- jvm hotspot 源码分析 RandomAccessfile vs FileChannel 写文件
- Windows源码安装PyTorch 0.4
- epoll vs select———— epoll内核源码理解
- Windows平台使用VS2013编译VLC源码
- vs2015 去除 git 源代码 绑定,改成向tfs添加源码管理
- VS2013编译python源码