Conversation
@@ -0,0 +1,969 @@
# Licensed to the Apache Software Foundation (ASF) under one
please rename this file to test_control_flow.py
Fixed :-)
test_while_loop_simple_forward()
test_while_loop_for_foreach()
test_while_loop_nested()
test_while_loop_rnn()
you don't need to call the test functions explicitly.
Oops I forgot to uncomment the lines above. Fixed :-)
@szha @piiswrong @eric-haibin-lin please help review this PR.
python/mxnet/ndarray/contrib.py
Outdated
The number of elements, shape, dtype of each element in `step_output` should be consistent.
The `new_loop_vars` should be consistent with `loop_vars` on each step.
The `func` is variadic, and its signature should be
`cond(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`.
`cond` => `func`.
Fixed :-)
# some loop_vars are inputs to `graph`, some are not
name_to_loop_vars = {sym.name: sym for sym in loop_vars}
# other inputs to `graph` created by cut_graph
name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
what is the difference between `list_outputs()[0]` and `name`?
I feel like they are equivalent. I just copied this from the corresponding line in `foreach`.
else:
    sym = copy.deepcopy(name_to_input_syms[name])
# do 2), and 1) is implicitly done
if id(sym) in input_id_to_loc:
why does `id(sym)` exist in `input_id_to_loc` before?
There are several subgraphs; more specifically, two subgraphs, `cond` and `func`, in while_loop. They may have common input symbols, so these symbols may have been added to `input_id_to_loc` already.
Why the id instead of checking for the symbol directly?
@szha `mx.sym.Symbol.__eq__` has been overridden and returns an NDArray instead of a bool, so directly using `sym` as a dict key won't work.
OK
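To make the `id()` workaround concrete, here is a minimal sketch (the dict below is illustrative, mirroring the `input_id_to_loc` from the excerpt above):

import mxnet as mx

a = mx.sym.var("a")
b = mx.sym.var("b")
# Symbol.__eq__ is overridden to build a comparison node in the graph,
# so it cannot serve as the boolean equality check a dict lookup needs
print(type(a == b))  # a graph node, not a Python bool
# keying by identity sidesteps the overridden __eq__ entirely
input_id_to_loc = {id(a): 0, id(b): 1}
print(id(a) in input_id_to_loc)  # True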
src/operator/control_flow.cc
Outdated
MSHADOW_TYPE_SWITCH(a.dtype(), DType, {
  return static_cast<bool>(_asscalar<DType>(a));
});
CHECK(false) << "Unknown dtype";
use `LOG(FATAL)`?
Fixed :-)
src/operator/control_flow.cc
Outdated
CHECK_EQ(arr.storage_type(), kDefaultStorage)
  << "The while_loop operator doesn't support the sparse format";
// construct inputs and outputs for cond
std::vector<NDArray> cond_inputs, cond_outputs = {NDArray()};
shouldn't the condition graph always have outputs? Why construct a vector with an empty NDArray?
I just wanted a placeholder to hold the output of `cond`. Is there a better way to do this? Thank you!
// req[i] (step == n_iters - 1 && i not in loop_var_locs)
{
  size_t i = 0;
  for (size_t loc : var_locs) {
what about the variables behind the loop variables? for example, the inputs of the loop are
[loop_var1, loop_var2, other_var1, other_var2]. How are other_var1 and other_var2 processed?
In line 745, a dummy position is appended to `var_locs`, which covers those remaining variables.
    max_iterations=10,
)
if hybridize:
    model.hybridize(inline_limit=0)
is `inline_limit` necessary?
i don't think people use `inline_limit` by default. we should test common cases.
Removed
assert result[0].asscalar() == 1
assert result[1].asscalar() == 0
assert result[2].asscalar() == 0
# Case 2.1: result should be sum([1, 2, 3 ... 100])
the cases are labelled with xx.yy. do xx and yy have special meanings?
In xx.yy, I use `xx` as the id of a testing template; for instance, `def case_1(**param)` is a template for various test cases. `yy` refers to the different shapes and other conditions I specify in the template.
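For illustration, a hypothetical template in that style (the name `case_1` and the shapes/dtypes below are invented for this sketch, not taken from the actual test file):

import mxnet as mx

def case_1(shape, dtype):
    # template "1" (the xx part): accumulate ones of a given shape/dtype
    cond = lambda i, s: i < 3
    func = lambda i, s: ([s], [i + 1, s + mx.nd.ones(shape, dtype=dtype)])
    _, states = mx.nd.contrib.while_loop(
        cond=cond, func=func,
        loop_vars=(mx.nd.array([0]), mx.nd.zeros(shape, dtype=dtype)),
        max_iterations=5)
    assert (states[1].asnumpy() == 3).all()

def test_case_1():
    case_1(shape=(2, 3), dtype="float32")  # case 1.1
    case_1(shape=(1,), dtype="float64")    # case 1.2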
assert all(outputs.asnumpy() == np.arange(1, 1001).reshape(1000, 1))
assert result_i.asscalar() == 1001
assert result_s.asscalar() == 500500
# Case 2.3: very corner case
what corner case does it test?
The loop body is never executed. Thanks for the review; I changed the comment to "Case 2.3: a corner case, in which the loop body is never executed" to make it clear.
# It is for the case that inputs & outputs are the same
# There are 3 outputs
# There are 4 states: i, s_0, s_1, s_2
# i is used in both differentiable (take) and non-differentiable (+) occasions
take is differentiable, + is non-differentiable?
My mistake. Fixed :-)
@szha @piiswrong @eric-haibin-lin Hey, could you help take a look at the code?
@zheng-da Hey, I addressed all of the comments except the one in
There is a problem in the current implementation: while_loop doesn't inform users of the number of iterations it actually performs, so users can't know how much of the data in the output arrays is valid. One option is to return a mask or a scalar symbol to indicate the number of iterations.
@zheng-da For users, an additional variable like
python/mxnet/ndarray/contrib.py
Outdated
@@ -191,3 +191,128 @@ def check_input(inputs, in_type, msg):
    if not_data_list and len(outputs) == 1:
        outputs = outputs[0]
    return (outputs, states)

def while_loop(loop_vars, cond, func, max_iterations):
put loop_vars after func?
My bad. Fixed :-)
python/mxnet/ndarray/contrib.py
Outdated
`func` is a user-defined function as the loop body.
It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
The number of elements, shape, dtype of each element in `step_output` should be consistent.
what does consistent mean?
I changed this to:

    In each step, `step_output` should contain the same number of elements. Through all steps, the i-th element of `step_output` should have the same shape and dtype. Also, `new_loop_vars` should contain the same number of elements as `loop_vars`, and the corresponding element should have the same shape and dtype.
Does this seem better?
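To illustrate the rule, a hedged pair of loop bodies (`good_func` and `bad_func` are invented for this sketch):

import mxnet as mx

# consistent: every step emits one output of shape (1,), and new_loop_vars
# match loop_vars in number, shape, and dtype
good_func = lambda i, s: ([i + s], [i + 1, s + i])

# inconsistent: the output's shape grows with the counter, so the step
# outputs cannot be stacked along axis 0 across steps
def bad_func(i, s):
    n = int(i.asscalar()) + 1
    return [mx.nd.ones((n,))], [i + 1, s]

Calling `while_loop` with `bad_func` would fail at the stacking step, which is exactly the inconsistency the docstring warns about.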
steps = 0
outputs = []
while steps < max_iterations and \
        _to_python_scalar(cond(*loop_vars), bool, "Return value of cond"):  # loop condition
So this could end before reaching max_iterations. Isn't this inconsistent with symbol?
Yes, they are not consistent, and I put a warning in the docstring. Should I do some padding stuff so that they look the same?
i think so.
@zheng-da So should I pad the arrays to make them consistent?
it's better to do so, in my opinion. what do you think? @piiswrong
Fixed
Yes, ndarray and symbol functions should give the same result for the same input. Otherwise hybridize may break.
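A sketch of that parity requirement (assuming the padded imperative behavior this PR settles on; `Loop` is an illustrative block, not from the PR):

import mxnet as mx
from mxnet import gluon

class Loop(gluon.HybridBlock):
    def hybrid_forward(self, F, i, s):
        outs, _ = F.contrib.while_loop(
            cond=lambda i, s: i <= 5,
            func=lambda i, s: ([i + s], [i + 1, s + i]),
            loop_vars=(i, s),
            max_iterations=10)
        return outs[0]

block = Loop()
block.initialize()
imperative = block(mx.nd.array([0]), mx.nd.array([1]))
block.hybridize()
symbolic = block(mx.nd.array([0]), mx.nd.array([1]))
assert imperative.shape == symbolic.shape  # both padded to (10, 1)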
python/mxnet/ndarray/contrib.py
Outdated
try:
    outputs = list(ndarray.op.stack(*item) for item in zip(*outputs))
except ValueError:
    raise ValueError("step_outputs are inconsistent on each step")
be explicit about which value is inconsistent. Print out the inconsistent shapes if possible
Fixed
python/mxnet/symbol/contrib.py
Outdated
@@ -336,3 +336,205 @@ def check_data(inputs, in_type, msg):
        states = states[0]

    return (outs, states)

def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"):
interface different from ndarray?
My bad. Fixed
python/mxnet/ndarray/contrib.py
Outdated
`max_iterations` is a scalar that defines the maximum number of iterations allowed.

This function returns a list of NDArrays of length `|step_output| + |loop_vars|`.
The return value has a different format from TF's. Any specific reasons?
I updated the outdated docstring to the following:

    This function returns two lists as a tuple. The first list has the length of `|step_output|`, in which the i-th element stacks the i-th elements of `step_output` from all steps along axis 0. The second list has the length of `|loop_vars|`, and represents the final states of the loop variables.

Currently we don't have dynamic shape inference, so we cannot support a TF-like dynamic-sized, per-time-step `TensorArray`. So we split our requirement into two parts: 1) a per-time-step array capped at `max_iterations`; 2) loop variables which keep the same shape through the loop.
the interface of this while_loop operator is close to the ONNX definition. https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop
The interface of TF while_loop requires dynamic shape and also doesn't allow efficient implementation. This interface is more flexible. If it returns `([], loop_vars)`, it's the same as the TF interface.
So does it return two lists separately or concatenated together?
@piiswrong Separately
"This function returns two lists as a tuple" -> "This function returns two lists."
"as a tuple" makes it sound like the two lists are concatenated into a tuple.
@piiswrong fixed
python/mxnet/symbol/contrib.py
Outdated
`max_iterations` is a scalar that defines the maximum number of iterations allowed.

This function returns a list of Symbols of length `|step_output| + |loop_vars|`.
The i-th element in the first `|step_output|` ones of the list represent
I don't understand this sentence.
My bad, this is outdated. Does the following seem better?
    This function returns two lists as a tuple. The first list has the length of `|step_output|`, in which the i-th element stacks the i-th elements of `step_output` from all steps along axis 0. The second list has the length of `|loop_vars|`, and represents the final states of the loop variables.
python/mxnet/symbol/contrib.py
Outdated
@@ -336,3 +336,205 @@ def check_data(inputs, in_type, msg):
        states = states[0]

    return (outs, states)

def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"):
    """Run a while loop with user-defined computation and loop condition.
Is max_iterations always required?
yes, without dynamic shape, a user has to provide max_iterations
python/mxnet/ndarray/contrib.py
Outdated
>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
>>> outputs, states = mx.nd.contrib.while_loop(loop_vars, cond, func, max_iterations=10)
Can you show the output results of this example?
The results are:
>>> outputs
[
[[ 1]
[ 2]
[ 4]
[ 7]
[11]
[16]]
<NDArray 6x1 @cpu(0)>]
>>> states
[
[6]
<NDArray 1 @cpu(0)>,
[16]
<NDArray 1 @cpu(0)>]
Should I put this snippet into the docstring?
i think so.
Fixed
@piiswrong Can we merge this PR?
@junrushao1994 thanks for working on this operator! I can't wait to have this available in MXNet. I took the liberty of checking out this PR and compiling MXNet on a mac to test this operator.

>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.sym.var('i'), mx.sym.var('s'))
>>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10)

I am evaluating this symbolic graph like this to inspect the outputs from each iteration:

mx.sym.Group(outputs + states).eval(i=mx.nd.zeros((1,)), s=mx.nd.ones((1,)))[0].asnumpy().astype('int32')

As expected, the output is padded to `max_iterations` rows. However, the padding values differ between runs:

In [65]: mx.sym.Group(outputs + states).eval(i=mx.nd.zeros((1,)), s=mx.nd.ones((1,)))[0].asnumpy().astype('int32')
Out[65]:
array([[          1],
       [          2],
       [          4],
       [          7],
       [         11],
       [         16],
       [-2147483648],
       [          0],
       [-2147483648],
       [          0]], dtype=int32)

In [69]: mx.sym.Group(outputs + states).eval(i=mx.nd.zeros((1,)), s=mx.nd.ones((1,)))[0].asnumpy().astype('int32')
Out[69]:
array([[ 1],
       [ 2],
       [ 4],
       [ 7],
       [11],
       [16],
       [ 0],
       [ 0],
       [ 0],
       [ 0]], dtype=int32)

It seems that the padding values are non-deterministic.
@fhieber Thank you so much for pointing this out. Yes, it is expected behavior, documented here: https://github.com/junrushao1994/mxnet/blob/while_loop_pr/python/mxnet/ndarray/contrib.py#L266. Actually I discussed this with my mentor @zheng-da several days ago, and was advised not to fill in a deterministic value, in order to ensure performance. Once dynamic shape inference is out, there will be no padding at all. Currently, if you want to know the exact number of steps taken, you could add an `i` into the loop variables and increment it on each step. Thank you so much for the detailed feedback!
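A minimal sketch of that counting trick, reusing the example from the docstring (argument order per the final interface; `n_steps` and `valid` are illustrative names):

import mxnet as mx

cond = lambda i, s: i <= 5
func = lambda i, s: ([i + s], [i + 1, s + i])
loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
outputs, states = mx.nd.contrib.while_loop(
    cond=cond, func=func, loop_vars=loop_vars, max_iterations=10)
n_steps = int(states[0].asscalar())  # `i` counts the steps actually taken
valid = outputs[0][:n_steps]         # slice off the undefined padding rows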
Thanks! I actually missed that comment in the documentation, sorry about that. Makes sense to keep the fill value undefined for performance reasons. I am considering this operator as the core building block for a beam search algorithm capped with max_iterations. I would then, exactly as you suggested, keep a loop variable that counts the actual number of steps taken, and use it to slice the results from the returned, padded output.
lgtm
         gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
ctxs = [mx.cpu(0), mx.gpu(0)]
for cell in cells:
    for ctx in ctxs:
        for batch_size in batch_sizes:
            if len(get_gpus()) == 0 and ctx == mx.gpu(0):
                continue

            if isinstance(cell, gluon.rnn.GRUCell):
            if isinstance(cell, gluon.rnn.RNNCell):
nit: there's quite a bit of repetition in the below code.
I didn't touch this file. It is renamed from https://github.com/apache/incubator-mxnet/blob/master/benchmark/python/control_flow/rnn.py. Should I simplify this in this PR, or in a separate one?
you can coordinate with @zheng-da
@szha Da and I decide that I rewrite these two files. Will push a commit later today.
out, states = F.contrib.while_loop(
    cond=lambda i, *_: i < self.length,
    func=_func,
    # lambda i, *s: [i + 1] + list(self.cell(s)),
remove commented code
Removed. Sorry for that T_T
num_batches = 20

# Imperative
cell0 = copy.deepcopy(cell)
nit: there's quite a bit of repetition in this function
python/mxnet/ndarray/contrib.py
Outdated
The second list has the length of `|loop_vars|`,
which represents final states of loop variables.

Warning 1: for now, the axis 0 of all NDArrays in the first list are `max_iterations`,
Use markup in sphinx for the warning?

.. warning::
   For now, the axis 0 of all NDArrays in the first list are `max_iterations`, due to lack of dynamic shape inference.
Cool! Updated
python/mxnet/ndarray/contrib.py
Outdated
Returns
-------
outputs: two lists, which both contains 0, 1 or more NDArrays.
Returns
-------
outputs: list of NDArray
    stacked output from each step
states: list of NDArray
    final state
That makes much more sense. Updated. Thank you!
python/mxnet/ndarray/contrib.py
Outdated
for i_th, items in enumerate(zip(*outputs), 1):
    # `mx.ndarray.pad` only support 4-D or 5-D inputs for now
    # so we could not use it.
    items = [x.reshape([1] + list(x.shape)) for x in items]
use expand_dims(0)
Fixed.
python/mxnet/ndarray/contrib.py
Outdated
items = [x.reshape([1] + list(x.shape)) for x in items]
if steps != max_iterations and items:
    pad_shape = [max_iterations - steps] + list(items[0].shape[1:])
    pad = ndarray.empty(
why only pad at the end instead of allocating empty arrays upfront?
Autograd does not support in-place operations, such as writing into an NDArray in place.
🤦♂️
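For readers of this thread, a rough sketch of the pad-at-the-end shape logic (the values are made up, and the real code in contrib.py differs in detail):

import mxnet as mx
from mxnet import ndarray

steps, max_iterations = 6, 10
items = [mx.nd.array([[v]]) for v in (1, 2, 4, 7, 11, 16)]  # each (1, 1)
stacked = ndarray.op.concat(*items, dim=0)  # (steps, 1), recorded by autograd
if steps != max_iterations:
    pad_shape = [max_iterations - steps] + list(items[0].shape[1:])
    # append an uninitialized block instead of writing rows in place
    pad = ndarray.empty(pad_shape, ctx=items[0].context, dtype=items[0].dtype)
    stacked = ndarray.op.concat(stacked, pad, dim=0)  # (max_iterations, 1)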
python/mxnet/symbol/contrib.py
Outdated
which represents final states of loop variables.

.. warning::
For now, the axis 0 of all Symbols in the first list are `max_iterations`,
there should be an empty line before and after the warning block. Also, indentation is needed for it to work. refer to sphinx doc
@szha Sorry for the mistake. I modified the warning block according to the sphinx example just now. Could you help take a look again?
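For reference, roughly how the corrected block sits in the docstring (blank lines around it plus indentation; the surrounding text is abridged):

def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
    """Run a while loop with user-defined computation and loop condition.

    ...

    .. warning::

       For now, the axis 0 of all Symbols in the first list is
       `max_iterations`, due to the lack of dynamic shape inference.

    ...
    """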
Just now I rewrote the benchmark code for both
* Add while_loop
* Avoid input/output overlap for nnvm graph cut
* Add more testcases
* Enhance test 4.2
* Add more complicated testcases; Add testcase for nested loop
* Check unused loop_vars in while_loop
* Add testcases for RNN
* Make lint happy
* Make lint happy
* Address TODOs
* Fix flaky test for while_loop
* Address comments
* Improve docstring
* Improve error message
* Add benchmark code
* Update benchmarks
* Allow sparse types
* Make max_iterations default to None
* Add while_loop to docs/api/python/{symbol|ndarray}/contrib.md
* Pad imperative while_loop so that it has the same shape with the symbolic one
* Add example result into the example section
* Remove unused class member
* Rename unittest to test_contrib_control_flow.py
* Update docstring
* Update docstring
* Trigger CI
* Change threshold for assert_almost_equal
* Trigger CI
* Address comments from szha
* Rewrite benchmark code
* Fix sphinx warning
Description
This PR is part of the proposal of adding a set of control flow operators to MXNet. Link to proposal.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
* src/operator/control_flow.cc
* TODO: address the `TODO(Junru)` comments in the changes