Pytorch反向求导更新网络参数的方法
2019-08-17 17:57
1291 查看
方法一:手动计算变量的梯度,然后更新梯度
import torch from torch.autograd import Variable # 定义参数 w1 = Variable(torch.FloatTensor([1,2,3]),requires_grad = True) # 定义输出 d = torch.mean(w1) # 反向求导 d.backward() # 定义学习率等参数 lr = 0.001 # 手动更新参数 w1.data.zero_() # BP求导更新参数之前,需先对导数置0 w1.data.sub_(lr*w1.grad.data)
一个网络中通常有很多变量,如果按照上述的方法手动求导,然后更新参数,是很麻烦的,这个时候可以调用torch.optim
方法二:使用torch.optim
import torch from torch.autograd import Variable import torch.nn as nn import torch.optim as optim # 这里假设我们定义了一个网络,为net steps = 10000 # 定义一个optim对象 optimizer = optim.SGD(net.parameters(), lr = 0.01) # 在for循环中更新参数 for i in range(steps): optimizer.zero_grad() # 对网络中参数当前的导数置0 output = net(input) # 网络前向计算 loss = criterion(output, target) # 计算损失 loss.backward() # 得到模型中参数对当前输入的梯度 optimizer.step() # 更新参数
注意:torch.optim只用于参数更新和对参数的梯度置0,不能计算参数的梯度,在使用torch.optim进行参数更新之前,需要写前向与反向传播求导的代码
以上这篇Pytorch反向求导更新网络参数的方法就是小编分享给大家的全部内容了,希望能给大家一个参考
您可能感兴趣的文章:
相关文章推荐
- Pytorch反向求导更新网络参数
- caffe 如何让反向传播不更新某些层,即固定网络参数
- 神经网络更新参数的几种方法
- Tensorflow笔记:反向传播,搭建神经网络的八股,(损失函数loss,均方误差MSE,反向传播训练方法,学习率)
- Android Volley网络请求框架 实现post方法并带Map参数上传
- LR测试文件/表参数的数据分配和更新方法(十)
- 快递100 官方api技术文档 错误 更新 快递公司网络异常 解决方法
- WCF客户端引用带有 int bool 类型的方法时,会自动加上一个Specified参数的 解决方法 Web Reference for a WCF Service has Extra “IdSpecified” Parameter -摘自网络
- caffe的finetuning是如何更新网络参数的
- jquery uploadify动态更新配置参数方法uploadifySettings()
- Android Volley网络请求框架 实现post方法并带Map参数上传
- Tensorflow笔记:反向传播参数更新推导过程
- 低功耗蓝牙BLE之连接事件、连接参数和更新方法
- pytorch 网络参数初始化
- 低功耗蓝牙BLE之连接事件、连接参数和更新方法
- 低功耗蓝牙BLE之连接事件、连接参数和更新方法
- mysql中max_allowed_packet参数的配置方法(避免大数据写入或者更新失败)
- mysql中max_allowed_packet参数的配置方法(避免大数据写入或者更新失败)
- 删除和修改caffe模型中任意最后一层或者任意层数网络的参数的方法
- 四种方法保障网络参数设置的安全