
TensorFlow's while_loop

2016-12-16 14:44
So, here's what I understand so far. Perhaps you can correct any misunderstandings, and in the meantime this can serve as a useful resource for anyone else who runs into this.

tf.while_loop accepts a list of loop variables, a function mapping the loop variables to a boolean (the condition), and a function mapping the loop variables to a new set of loop variables (the body).
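For concreteness, here is a minimal usage sketch in TF 1.x graph-mode style (the variable names are mine, not from the original post):

```python
import tensorflow as tf

# Sum the integers 0..9 with a while_loop.
# Loop variables: (i, total); cond and body both take and return them.
i0 = tf.constant(0)
total0 = tf.constant(0)

cond = lambda i, total: tf.less(i, 10)             # keep looping while i < 10
body = lambda i, total: [tf.add(i, 1), tf.add(total, i)]

i_final, total_final = tf.while_loop(cond, body, loop_vars=[i0, total0])

with tf.Session() as sess:
    print(sess.run([i_final, total_final]))        # [10, 45]
```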
Internally, this is represented using the special nodes Enter, Exit, NextIteration, Switch, and Merge. Enter, Exit, and NextIteration are all semantically equivalent to identity ops (they just forward their input to their output, potentially as a reference), but the fact that they have type Enter, Exit, or NextIteration is used by the executor to handle them in a special way. The graph is constructed as follows:
The loop variables are sent through "Enter" nodes.
The Enter node outputs are then given to "Merge" nodes. During graph construction, the inputs to each Merge node are temporarily two copies of the corresponding Enter node; when the NextIteration nodes are constructed, the Merge nodes are fixed up by replacing one of the Enter inputs with the NextIteration input. In this way, every Merge node (one per variable) gets one input from its variable's Enter node and one from its NextIteration node.
The outputs of the Merge nodes are passed to the condition function, which takes them and outputs a boolean. This boolean is passed to a LoopCond node. The LoopCond output, as well as the output of each Merge node, is passed to Switch nodes, again one per variable. Depending on the boolean, each Switch node outputs a dead tensor on one of its outputs and a live tensor (the Merge node output) on the other.
The output of each Switch node is sent either to an Exit node (one per variable) or to an Identity op (one per variable), depending on whether the loop condition is false or true.
The identity op output is given to the loop body, and the outputs of the loop body are fed to NextIteration ops; these ops are the ones patched back in as inputs to the Merge nodes.
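These primitives can be seen directly in the graph that tf.while_loop builds under the classic (TF 1.x) control flow; a small sketch, assuming graph mode:

```python
import tensorflow as tf
from collections import Counter

g = tf.Graph()
with g.as_default():
    i = tf.constant(0)
    r = tf.while_loop(lambda i: i < 5, lambda i: i + 1, [i])

# Count the control-flow primitives that while_loop inserted into the graph.
op_types = Counter(op.type for op in g.get_operations())
for t in ("Enter", "Merge", "LoopCond", "Switch", "Identity", "NextIteration", "Exit"):
    print(t, op_types.get(t, 0))
```

Each of these types should show up at least once (roughly once per loop variable, with LoopCond appearing once per loop). Newer TensorFlow versions that default to the functional While op will not show these nodes.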

The executor has special support for these five primitive ops, which turn this structure into a loop: The executor has a concept of a Frame, which is essentially the current iteration of the innermost loop. A frame has state, where all the input and output tensors are stored in a flat vector; each op writes its outputs to a subset of that vector and gets its inputs from a subset of it, so the inputs and outputs of an op can be obtained by just going to the right offset in this vector of Entry values.

A new frame is created when the executor sees an Enter node. A frame is removed when it sees an Exit node. The frame advances to its next iteration when the executor sees a NextIteration node: it finds the child of that node (namely the Merge op) and calls ActivateNode on it, in order to continue the loop. Since nodes are not marked ready until all their inputs are non-dead, the nodes that get dead inputs from Switch (i.e., when the loop is done) will not get run again.
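Purely as an illustration of the dataflow described above, here is a toy Python sketch of one loop frame (this is not TensorFlow code; the real executor is C++ and far more general, and all names here are mine):

```python
# Toy illustration only: a hand-rolled "executor" for one loop frame,
# mimicking how the control-flow primitives drive the loop variables.

def run_toy_loop(init_value, cond, body):
    # "Enter": the value enters a fresh frame for this loop.
    frame = {"iteration": 0, "value": init_value}
    while True:
        # "Merge": on iteration 0 the value comes from Enter;
        # afterwards it comes from NextIteration.
        merged = frame["value"]
        # "LoopCond" + "Switch": route the value to the body (live branch)
        # or toward Exit (the other branch becomes dead).
        if not cond(merged):
            # "Exit": leave the frame and forward the final value outward.
            return merged
        # "Identity" -> loop body -> "NextIteration": the body output is fed
        # back to Merge, and the frame advances to the next iteration.
        frame["value"] = body(merged)
        frame["iteration"] += 1

# The same sum-of-0..9 loop as above, run through the toy frame logic.
print(run_toy_loop((0, 0),
                   cond=lambda v: v[0] < 10,
                   body=lambda v: (v[0] + 1, v[1] + v[0])))   # prints (10, 45)
```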
For every loop during forward propagation, a few things have to happen to create the backprop gradient graph:
First of all, a loop is added to the forward propagation which counts the number of iterations. More precisely, the counter is grafted onto the original loop; this works because of the way the primitive ops are designed. This counter loop starts with an f_count Enter node and is created in control_flow_ops.py by AddForwardLoopCounter.
A history_map is maintained of the tensors produced during forward prop. Whenever the backprop needs a tensor from the forward prop, a stack is introduced: the forward prop gets a StackPush added to it, while the backprop gets a StackPop added to it that pops from the same stack. In that manner, the forward prop pushes anything the backprop will need onto a stack, and the backprop slowly consumes that stack.
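One way to observe this machinery (a sketch assuming TF 1.x control flow; the exact op types and node names vary across versions) is to take a gradient through a while_loop and look for the stack and counter ops that gradient construction adds:

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    x = tf.constant(2.0)
    # y = x ** 5 computed by a loop, so backprop needs values from each iteration.
    _, y = tf.while_loop(lambda i, acc: i < 5,
                         lambda i, acc: [i + 1, acc * x],
                         [tf.constant(0), tf.constant(1.0)])
    dy_dx = tf.gradients(y, x)[0]

# Gradient construction adds a forward iteration counter and stack push/pop ops.
for op in g.get_operations():
    if "Stack" in op.type or "f_count" in op.name:
        print(op.type, op.name)

with tf.Session(graph=g) as sess:
    print(sess.run(dy_dx))   # d(x**5)/dx at x = 2 -> 80.0
```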

The description above is not quite complete, but I think I understand enough for what I want to do.

Questions:

Why is there a LoopCond node? Why not pass the output of the condition directly to Switch?
What was the motivation for such a seemingly complicated set of primitive ops? It seems like it's possible to build other control structures on top of them; is that the goal? Were these primitive ops chosen because they make it possible to implement the fairly complex gradient loop generation?
What is an example use case for parallel_iterations? (This is a simple question which might make sense to add to the tf.while_loop docs.)