pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
2017-06-15 13:23
507 查看
pytorch 的 hook 机制
在看pytorch官方文档的时候,发现在
nn.Module部分和
Variable部分均有
hook的身影。感到很神奇,因为在使用
tensorflow的时候没有碰到过这个词。所以打算一探究竟。
Variable 的 hook
register_hook(hook)
注册一个backward钩子。
每次
gradients被计算的时候,这个
hook都被调用。
hook应该拥有以下签名:
hook(grad) -> Variable or None
hook不应该修改它的输入,但是它可以返回一个替代当前梯度的新梯度。
这个函数返回一个 句柄(
handle)。它有一个方法
handle.remove(),可以用这个方法将
hook从
module移除。
例子:
v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True) h = v.register_hook(lambda grad: grad * 2) # double the gradient v.backward(torch.Tensor([1, 1, 1])) #先计算原始梯度,再进hook,获得一个新梯度。 print(v.grad.data) h.remove() # removes the hook1
2
3
4
5
6
1
2
3
4
5
6
2 2 2 [torch.FloatTensor of size 3]1
2
3
4
1
2
3
4
nn.Module的hook
register_forward_hook(hook)
在module上注册一个
forward hook。
每次调用
forward()计算输出的时候,这个
hook就会被调用。它应该拥有以下签名:
hook(module, input, output) -> None
hook不应该修改
input和
output的值。
这个函数返回一个 句柄(
handle)。它有一个方法
handle.remove(),可以用这个方法将
hook从
module移除。
看这个解释可能有点蒙逼,但是如果要看一下
nn.Module的源码怎么使用
hook的话,那就乌云尽散了。
先看
register_forward_hook
def register_forward_hook(self, hook): handle = hooks.RemovableHandle(self._forward_hooks) self._forward_hooks[handle.id] = hook return handle1
2
3
4
5
1
2
3
4
5
这个方法的作用是在此
module上注册一个
hook,函数中第一句就没必要在意了,主要看第二句,是把注册的
hook保存在
_forward_hooks字典里。
再看
nn.Module的
__call__方法(被阉割了,只留下需要关注的部分):
def __call__(self, *input, **kwargs): result = self.forward(*input, **kwargs) for hook in self._forward_hooks.values(): #将注册的hook拿出来用 hook_result = hook(self, input, result) ... return result1
2
3
4
5
6
7
8
1
2
3
4
5
6
7
8
可以看到,当我们执行
model(x)的时候,底层干了以下几件事:
调用
forward方法计算结果
判断有没有注册
forward_hook,有的话,就将
forward的输入及结果作为
hook的实参。然后让
hook自己干一些不可告人的事情。
看到这,我们就明白
hook签名的意思了,还有为什么
hook不能修改
input的
output的原因。
小例子:
import torch from torch import nn import torch.functional as F from torch.autograd import Variable def for_hook(module, input, output): print(module) for val in input: print("input val:",val) for out_val in 14fd1 output: print("output val:", out_val) class Model(nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, x): return x+1 model = Model() x = Variable(torch.FloatTensor([1]), requires_grad=True) handle = model.register_forward_hook(for_hook) print(model(x)) handle.remove()1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
register_backward_hook
在module上注册一个
bachward hook。此方法目前只能用在
Module上,不能用在
Container上,当
Module的forward函数中只有一个
Function的时候,称为
Module,如果
Module包含其它
Module,称之为
Container
每次计算
module的
inputs的梯度的时候,这个
hook会被调用。
hook应该拥有下面的
signature。
hook(module, grad_input, grad_output) -> Tensor or None
如果
module有多个输入输出的话,那么
grad_input
grad_output将会是个
tuple。
hook不应该修改它的
arguments,但是它可以选择性的返回关于输入的梯度,这个返回的梯度在后续的计算中会替代
grad_input。
这个函数返回一个 句柄(
handle)。它有一个方法
handle.remove(),可以用这个方法将
hook从
module移除。
从上边描述来看,
backward hook似乎可以帮助我们处理一下计算完的梯度。看下面
nn.Module中
register_backward_hook方法的实现,和
register_forward_hook方法的实现几乎一样,都是用字典把注册的
hook保存起来。
def register_backward_hook(self, hook): handle = hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle1
2
3
4
1
2
3
4
先看个例子来看一下
hook的参数代表了什么:
import torch from torch.autograd import Variable from torch.nn import Parameter import torch.nn as nn import math def bh(m,gi,go): print("Grad Input") print(gi) print("Grad Output") print(go) return gi[0]*0,gi[1]*0 class Linear(nn.Module): def __init__(self, in_features, out_features, bias=True): super(Linear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(out_features, in_features)) if bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): stdv = 1. / math.sqrt(self.weight.size(1)) self.weight.data.uniform_(-stdv, stdv) if self.bias is not None: self.bias.data.uniform_(-stdv, stdv) def forward(self, input): if self.bias is None: return self._backend.Linear()(input, self.weight) else: return self._backend.Linear()(input, self.weight, self.bias) x=Variable(torch.FloatTensor([[1, 2, 3]]),requires_grad=True) mod=Linear(3, 1, bias=False) mod.register_backward_hook(bh) # 在这里给module注册了backward hook out=mod(x) out.register_hook(lambda grad: 0.1*grad) #在这里给variable注册了 hook out.backward() print(['*']*20) print("x.grad", x.grad) print(mod.weight.grad)1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
Grad Input (Variable containing: 1.00000e-02 * 5.1902 -2.3778 -4.4071 [torch.FloatTensor of size 1x3] , Variable containing: 0.1000 0.2000 0.3000 [torch.FloatTensor of size 1x3] ) Grad Output (Variable containing: 0.1000 [torch.FloatTensor of size 1x1] ,) ['*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*'] x.grad Variable containing: 0 -0 -0 [torch.FloatTensor of size 1x3] Variable containing: 0 0 0 [torch.FloatTensor of size 1x3]1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
可以看出,
grad_in保存的是,此模块
Function方法的输入的值的梯度。
grad_out保存的是,此模块
forward方法返回值的梯度。我们不能在
grad_in上直接修改,但是我们可以返回一个新的
new_grad_in作为
Function方法
inputs的梯度。
上述代码对
variable和
module同时注册了
backward hook,这里要注意的是,无论是
module hook还是
variable hook,最终还是注册到
Function上的。这点通过查看
Varible的
register_hook源码和
Module的
__call__源码得知。
Module的register_backward_hook的行为在未来的几个版本可能会改变
BP过程中
Function中的动作可能是这样的
class Function: def __init__(self): ... def forward(self, inputs): ... return outputs def backward(self, grad_outs): ... return grad_ins def _backward(self, grad_outs): hooked_grad_outs = grad_outs for hook in hook_in_outputs: hooked_grad_outs = hook(hooked_grad_outs) grad_ins = self.backward(hooked_grad_outs) hooked_grad_ins = grad_ins for hook in hooks_in_module: hooked_grad_ins = hook(hooked_grad_ins) return hooked_grad_ins1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
关于
pytorch run_backward()的可能实现猜测为。
def run_backward(variable, gradient): creator = variable.creator if creator is None: variable.grad = variable.hook(gradient) return grad_ins = creator._backward(gradient) vars = creator.saved_variables for var, grad in zip(vars, grad_ins): run_backward(var, var.grad)1
2
3
4
5
6
7
8
9
1
2
3
4
5
6
7
8
9
中间Variable的梯度在BP的过程中是保存到GradBuffer中的(C++源码中可以看到), BP完会释放. 如果retain_grads=True的话,就不会被释放
相关文章推荐
- pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
- 关于Hook的一些理解
- IPTV 关于新的项目开始了,准备做连续的技术理解历程,和部分项目过程!
- pytorch学习笔记(九):PyTorch结构介绍
- Pytorch是什么?关于Pytorch!
- pytorch使用过程中遇到的一些问题
- 关于新品开发过程对应的工作流类型的进一步理解
- Mysql加锁过程详解(3)-关于mysql 幻读理解
- 关于浏览器渲染HTML过程的个人理解
- 关于js的callback回调函数以及嵌套回调函数的执行过程理解
- 对 VirtualApp hook过程的理解
- 关于面向对象和面向过程的程序设计思想的思考和理解
- JavaSE 关于复杂对象生成过程与try-finally的理解
- +++++关于INSTANCE RECOVERY过程的理解
- 关于信号函数处理过程中对信号的屏蔽理解。
- 关于INSTANCE RECOVERY过程的理解
- 写协议解析程序的过程(关于通信解析函数的理解)
- 我的关于编程中调用系统库的过程始终不理解???这篇文件解析了pe文件的结构
- pytorch学习笔记(九):PyTorch结构介绍
- Mysql加锁过程详解(2)-关于mysql 幻读理解