
Setting different learning rates for the weight and bias of one layer in PyTorch

After asking on the PyTorch forum (https://discuss.pytorch.org/t/how-to-set-different-learning-rate-for-weight-and-bias-in-one-layer/13450), here is a summary:

1. Set the dicts by hand — quick and crude, suitable for models with few layers

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer = nn.Linear(1, 1)
        # Initialize weight and bias to 1 so the update is easy to inspect.
        self.layer.weight.data.fill_(1)
        self.layer.bias.data.fill_(1)

    def forward(self, x):
        return self.layer(x)

if __name__ == "__main__":
    net = Net()
    # One param group per tensor: the weight uses the default lr (0.1),
    # while the bias overrides it with its own lr (0.01).
    optimizer = optim.Adam([
        {'params': net.layer.weight},
        {'params': net.layer.bias, 'lr': 0.01}
    ], lr=0.1, weight_decay=0.0001)
    out = net(Variable(torch.Tensor([[1]])))
    out.backward()
    optimizer.step()
    print("weight", net.layer.weight.data.numpy(), "grad", net.layer.weight.grad.data.numpy())
    print("bias", net.layer.bias.data.numpy(), "grad", net.layer.bias.grad.data.numpy())


If you use nn.Sequential, note that net.layer1.0.weight raises a syntax error; use net.layer1[0].weight instead, as in the sketch below.
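A minimal sketch of that indexing (the Sequential model here is hypothetical, just for illustration):

import torch.nn as nn
import torch.optim as optim

# Hypothetical model: two Linear layers inside an nn.Sequential.
net = nn.Sequential(nn.Linear(1, 4), nn.ReLU(), nn.Linear(4, 1))

# Submodules of a Sequential are indexed with [], not dotted access:
# net.0.weight is a Python syntax error, net[0].weight works.
optimizer = optim.Adam([
    {'params': net[0].weight},
    {'params': net[0].bias, 'lr': 0.01},
    {'params': net[2].parameters()},
], lr=0.1)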

2. Generate the list of dicts in code — suitable for models with many layers

For an example, see wkentaro/pytorch-fcn/blob/master/examples/voc/train_fcn32s.py#L105 on GitHub.

There, the function get_parameters() generates the dicts:

def get_parameters(model, bias=False):
    import torch.nn as nn
    import torchfcn  # the package the FCN model classes below come from
    # Modules that hold no trainable parameters of their own, or whose
    # children are handled elsewhere; yield nothing for these.
    modules_skipped = (
        nn.ReLU,
        nn.MaxPool2d,
        nn.Dropout2d,
        nn.Sequential,
        torchfcn.models.FCN32s,
        torchfcn.models.FCN16s,
        torchfcn.models.FCN8s,
    )
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            if bias:
                yield m.bias
            else:
                yield m.weight
        elif isinstance(m, nn.ConvTranspose2d):
            # weight is frozen because it is just a bilinear upsampling
            if bias:
                assert m.bias is None
        elif isinstance(m, modules_skipped):
            continue
        else:
            raise ValueError('Unexpected module: %s' % str(m))


Adjust the contents of modules_skipped to match your own model. It is applied like this:

optimizer = optim.SGD([
    {'params': get_parameters(model, bias=False)},
    {'params': get_parameters(model, bias=True),
     'lr': opt.lr * 0.1}
], lr=opt.lr, weight_decay=opt.weight_decay)
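Note that get_parameters() is a generator, so each call walks model.modules() afresh and the optimizer materializes it when the param groups are built. A quick sanity check of the resulting groups (my own addition, assuming the setup above):

# The two calls yield disjoint parameter sets: weights vs. biases.
weights = list(get_parameters(model, bias=False))
biases = list(get_parameters(model, bias=True))
assert not set(map(id, weights)) & set(map(id, biases))

print(len(optimizer.param_groups))                # 2
print([g['lr'] for g in optimizer.param_groups])  # [opt.lr, opt.lr * 0.1]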


Both methods still share one problem: when the learning rate has to change during training (e.g. step decay), I don't yet know how to reload the learning rates. For now I do it crudely by hand, editing the rates and loading the previous training checkpoint. How to get something like Caffe's automatic step learning-rate policy, I have no idea; if anyone has a lead, please share.
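One possible lead (my own sketch, not from the forum thread): torch.optim.lr_scheduler.StepLR, available in newer PyTorch releases, multiplies the lr of every param group by gamma every step_size epochs, so the weight/bias ratio configured above is preserved automatically:

from torch.optim import lr_scheduler

# Every param group's lr is scaled by gamma each step_size epochs,
# so per-group ratios (weight lr vs. bias lr) stay intact.
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(100):
    # ... one epoch of training with optimizer.step() ...
    scheduler.step()

# The same effect by hand, without a scheduler:
for group in optimizer.param_groups:
    group['lr'] *= 0.1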