Pytorch:利用torch.nn.Modules.parameters修改模型参数
2022-05-20 21:42
751 查看
1. 关于parameters()方法
Pytorch中继承了
torch.nn.Module的模型类具有
named_parameters()/parameters()方法,这两个方法都会返回一个用于迭代模型参数的迭代器(
named_parameters还包括参数名字):
import torch net = torch.nn.LSTM(input_size=512, hidden_size=64) print(net.parameters()) print(net.named_parameters()) # <generator object Module.parameters at 0x12a4e9890> # <generator object Module.named_parameters at 0x12a4e9890>
我们可以将
net.parameters()迭代器和将
net.named_parameters()转化为列表类型,前者列表元素是模型参数,后者是包含参数名和模型参数的元组。
当然,我们更多的是对迭代器直接进行迭代:
for param in net.parameters(): print(param.shape) # torch.Size([256, 512]) # torch.Size([256, 64]) # torch.Size([256]) # torch.Size([256]) for name, param in net.named_parameters(): print(name, param.shape) # weight_ih_l0 torch.Size([256, 512]) # weight_hh_l0 torch.Size([256, 64]) # bias_ih_l0 torch.Size([256]) # bias_hh_l0 torch.Size([256])
我们知道,Pytorch在进行优化时需要给优化器传入这个参数迭代器,如:
from torch.optim import RMSprop optimizer = RMSprop(net.parameters(), lr=0.01)
2. 关于参数修改
那么底层具体是怎么对参数进行修改的呢?
我们在博客《Python对象模型与序列迭代陷阱》中介绍过,Python序列中本身存放的就是对象的引用,而迭代器返回的是序列中的对象的二次引用,如果序列的引用指向基础数据类型,则是不可以通过遍历序列进行修改的,如:
my_list = [1, 2, 3, 4] for x in my_list: x += 1 print(my_list) #[1, 2, 3, 4]
而序列中的引用指向复合数据类型,则可以通过遍历序列来完成修改操作,如:
my_list = [[1, 2],[3, 4]] for sub_list in my_list: sub_list[0] += 1 print(my_list) # [1, 2, 3, 4] # [[2, 2], [4, 4]]
具体原理可参照该篇博客,此处我就不在赘述。这里想提到的是,用
net.parameters()/net.named_parameters()来迭代并修改参数,本质上就是上述第二种对复合数据类型序列的修改。我们可以如下写:
for param in net.parameters(): with torch.no_grad(): param += 1
with torch.no_grad():表示将将所要修改的张量关闭梯度计算。所增加的1会广播到
param张量的中的每一个元素上。上述操作本质上为:
for param in net.parameters(): with torch.no_grad(): param += torch.ones(param.shape)
但是需要注意,如果我们想让参数全部置为0,切不可像下列这样写:
for param in net.parameters(): with torch.no_grad(): param = torch.zeros(param.shape)
param是二次引用,
param=0操作再语义上会被解释为让
param这个二次引用去指向新的全0张量对象,但是对参数张量本身并不会产生任何变动。该操作实际上类似下列这种操作:
list_1 = [1, 2] list_2 = list_1 list_2 = [0, 0] print(list_1) # [1, 2]
修改二次引用
list_2自然不会影响到
list_1引用的对象。
下面让我们纠正这种错误,采用下列方法直接来将参数张量中的所有数值置0:
for param in net.parameters(): with torch.no_grad(): param[:] = 0 #张量类型自带广播操作,等效于param[:] = torch.zeros(param.shape)
这时语义上就类似
list_1 = [1, 2] list_2 = list_1 list_2[:] = [0, 0] print(list_1) # [0, 0]
自然就能完成修改的操作了。
参考
相关文章推荐
- pytorch中的pre-train函数模型或者旧的模型的引用及修改(增减网络层,修改某层参数等) finetune微调等
- pytorch:在网络中添加可训练参数,修改预训练权重文件
- 画pytorch模型图,以及参数计算的方法
- 利用pytorch构建简单的CNN模型(二)
- PyTorch中使用预训练的模型初始化网络的一部分参数
- 深度之眼Pytorch框架训练营第四期——模型创建与nn.Module
- Pytorch——把模型的所有参数的梯度清0
- PyTorch 入门实战(四)——利用Torch.nn构建卷积神经网络
- Pytorch:利用自带模型自定义构建VGG16网络以及其他网络
- PyTorch(一):PyTorch基础(PyTorch安装、Tensor、autograd、torch.nn、模型处理、数据处理)
- [PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module
- 解决了PyTorch 使用torch.nn.DataParallel 进行多GPU训练的一个BUG:模型(参数)和数据不在相同设备上
- pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件
- (继)pytorch中的pretrain模型网络结构修改
- 深度学习:pytorch用预训练pre-train模型微调参数
- Facebook 发布 PyTorch Hub:一行代码实现经典模型调用!
- pytorch forward两个参数实例
- Pytorch——torch.nn.Sequential()详解
- py-faster-rcnn训练自己数据集需要修改的参数
- 11G利用隐含参数,修改用户名