PyTorch学习系列(十)——如何在训练时固定一些层?
2017-05-27 10:52
218 查看
有时我们会用其他任务(如分类)预训练好网络,然后固定卷积层作为图像特征提取器,然后用当前任务的数据只训练全连接层。那么PyTorch如何在训练时固定底层只更新上层呢?这意味着我们希望反向传播计算梯度时,我们只希望计算到最上面的卷积层,对于卷积层,我们并不希望计算梯度并用梯度来更新参数。
我们知道,网络中的所有操作对象都是Variable对象,而Variable有两个参数可以用于这个目的:requires_grad和volatile。
在计算图中,如果有一个输入的requires_grad是True,那么输出的requires_grad也是True。只有在所有输入的requires_grad都为False时,输出的requires_grad才为False。
在训练时如果想要固定网络的底层,那么可以令这部分网络对应子图的参数requires_grad为False。这样,在反向过程中就不会计算这些参数对应的梯度:
我们知道,网络中的所有操作对象都是Variable对象,而Variable有两个参数可以用于这个目的:requires_grad和volatile。
requires_grad=False
在用户手动定义Variable时,参数requires_grad默认值是False。而在Module中的层在定义时,相关Variable的requires_grad参数默认是True。在计算图中,如果有一个输入的requires_grad是True,那么输出的requires_grad也是True。只有在所有输入的requires_grad都为False时,输出的requires_grad才为False。
>>>x = Variable(torch.randn(2, 3), requires_grad=True) >>>y = Variable(torch.randn(2, 3), requires_grad=False) >>>z = Variable(torch.randn(2, 3), requires_grad=False) >>>out1 = x+y >>>out1.requires_grad True >>>out2 = y+z >>>out2.requires_grad False
在训练时如果想要固定网络的底层,那么可以令这部分网络对应子图的参数requires_grad为False。这样,在反向过程中就不会计算这些参数对应的梯度:
model = torchvision.models.resnet18(pretrained=True) for param in model.parameters():#nn.Module有成员函数parameters() param.requires_grad = False # Replace the last fully-connected layer # Parameters of newly constructed modules have requires_grad=True by default model.fc = nn.Linear(512, 100)#resnet18中有self.fc,作为前向过程的最后一层。 # Optimize only the classifier optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)#optimizer用于更新网络参数,默认情况下更新所有的参数
volatile=True
Variable的参数volatile=True和requires_grad=False的功能差不多,但是volatile的力量更大。当有一个输入的volatile=True时,那么输出的volatile=True。volatile=True推荐在模型的推理过程(测试)中使用,这时只需要令输入的voliate=True,保证用最小的内存来执行推理,不会保存任何中间状态。>>> regular_input = Variable(torch.randn(5, 5)) >>> volatile_input = Variable(torch.randn(5, 5), volatile=True) >>> model = torchvision.models.resnet18(pretrained=True) >>> model(regular_input).requires_grad #输出的requires_grad应该是True,因为中间层的Variable的requires_grad默认是True True >>> model(volatile_input).requires_grad#输出的requires_grad是False,因为输出的volatile是True(等价于requires_grad是False) False >>> model(volatile_input).volatile True
相关文章推荐
- PyTorch学习系列(十五)——如何加载预训练模型?
- PyTorch学习系列(十六)——如何使用cuda进行训练?
- Caffe学习系列(23):如何将别人训练好的model用到自己的数据上
- CSS 布局实例系列(三)如何实现一个左右宽度固定,中间自适应的三列布局——也聊聊双飞翼
- 【特征工程系列2】如何获得训练数据的标签?
- Caffe学习系列(23):如何将别人训练好的model用到自己的数据上
- caffe学习系列(10):如何测试caffe训练出来的模型
- 深度学习入门系列博客(严重推荐)--如何训练 梯度消失 梯度爆炸等解释的明确
- Pytorch学习系列(八)——训练神经网络
- CSS 布局实例系列(二)如何通过 CSS 实现一个左边固定宽度、右边自适应的两列布局
- CSS 布局实例系列(二)如何通过 CSS 实现一个左边固定宽度、右边自适应的两列布局
- PyTorch学习系列(十四)——保存训练好的模型
- Caffe学习系列(23):如何将别人训练好的model用到自己的数据上
- Caffe学习系列(23):如何将别人训练好的model用到自己的数据上
- Caffe学习系列(23):如何将别人训练好的model用到自己的数据上
- iTextSharp 问题系列:在PDFPtable里加入图像,如何控制图像大小
- struts体系如何测试系列二
- 局域网乐趣系列一:如何共享上网,如何在特定的环境中共用一个账号,如何做代理服务器,如何代理上网,如何代理上qq
- MachII HowTo系列教程的译文: Mach-II 如何开发Listener
- 如何提升在使用DevExpress 系列DataGrid的性能问题