
[MXNET-626] Add while_loop #11566

Merged
merged 31 commits into apache:master from junrushao1994:while_loop_pr on Jul 19, 2018

Conversation

@junrushao (Member) commented Jul 5, 2018

Description

This PR is part of the proposal of adding a set of control flow operators to MXNet. Link to proposal.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Add while_loop operator in src/operator/control_flow.cc

TODO

  • Add a JIRA_ID
  • Add more testcases
  • Re-examine all TODO(Junru) in changes
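
For context, here is a minimal sketch of the user-facing API this PR adds, adapted from the docstring example discussed later in this thread (the final argument order is cond, func, loop_vars):

import mxnet as mx

# Loop while i <= 5; each step emits i + s and updates (i, s) -> (i + 1, s + i).
cond = lambda i, s: i <= 5
func = lambda i, s: ([i + s], [i + 1, s + i])
loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
# outputs: the per-step outputs, stacked along axis 0 (padded up to
# max_iterations, as discussed below); states: final loop variable values.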

@junrushao requested a review from szha as a code owner July 5, 2018 18:39
@junrushao changed the title [WIP] Add while_loop [MXNET-626] [WIP] Add while_loop Jul 6, 2018
@junrushao requested a review from anirudh2290 as a code owner July 8, 2018 08:26
@@ -0,0 +1,969 @@
# Licensed to the Apache Software Foundation (ASF) under one
Contributor:

please rename this file to test_control_flow.py

Member Author:

Fixed :-)

test_while_loop_simple_forward()
test_while_loop_for_foreach()
test_while_loop_nested()
test_while_loop_rnn()
Contributor:

you don't need to call the test functions explicitly.

Member Author:

Oops I forgot to uncomment the lines above. Fixed :-)

@zheng-da (Contributor) commented Jul 8, 2018

@szha @piiswrong @eric-haibin-lin please help review this PR.

The number of elements, shape, dtype of each element in `step_output` should be consistent.
The `new_loop_vars` should be consistent with `loop_vars` on each step.
The `func` is variadic, and its signature should be
`cond(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`.
Contributor:

cond => func.

Member Author:

Fixed :-)

# some loop_vars are inputs to `graph`, some are not
name_to_loop_vars = {sym.name: sym for sym in loop_vars}
# other inputs to `graph` created by cut_graph
name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
Contributor:

what is the difference between list_outputs()[0] and name?

Member Author:

I feel like they are equivalent; I just copied this from the corresponding line in foreach.
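
For what it's worth, the two only coincide for plain variables; a small illustration of my own (not from this PR):

import mxnet as mx

x = mx.sym.var("x")
print(x.name, x.list_outputs()[0])    # "x" "x" -- identical for a variable

y = mx.sym.Activation(x, act_type="relu", name="act")
print(y.name, y.list_outputs()[0])    # "act" "act_output" -- differ for an op node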

else:
sym = copy.deepcopy(name_to_input_syms[name])
# do 2), and 1) is implicitly done
if id(sym) in input_id_to_loc:
Contributor:

Why would id(sym) already exist in input_id_to_loc?

Member Author:

There are several subgraphs; specifically, while_loop has two subgraphs, cond and func. They may share input symbols, so those symbols may already have been added to input_id_to_loc.

Member:

Why the id instead of checking for the symbol directly?

Member Author:

@szha mx.sym.Symbol.__eq__ has been overridden and returns an elementwise comparison result instead of a bool. Thus directly using sym as a dict key won't work.

Member:

OK
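
A minimal sketch of the issue, assuming standard MXNet Symbol semantics:

import mxnet as mx

a = mx.sym.var("a")
b = mx.sym.var("b")
print(type(a == b))            # a Symbol (elementwise comparison), not a bool

# Keying by id() relies on object identity, sidestepping the overridden __eq__:
input_id_to_loc = {id(a): 0, id(b): 1}
assert id(a) in input_id_to_loc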

MSHADOW_TYPE_SWITCH(a.dtype(), DType, {
return static_cast<bool>(_asscalar<DType>(a));
});
CHECK(false) << "Unknown dtype";
Contributor:

use LOG(FATAL)?

Member Author:

Fixed :-)

CHECK_EQ(arr.storage_type(), kDefaultStorage)
<< "The while_loop operator doesn't support the sparse format";
// construct inputs and outputs for cond
std::vector<NDArray> cond_inputs, cond_outputs = {NDArray()};
Contributor:

Shouldn't the condition graph always have outputs? Why construct a vector with an empty NDArray?

Member Author:

I just wanted a placeholder to hold the output of cond. Is there a better way to do this? Thank you!

// req[i] (step == n_iters - 1 && i not in loop_var_locs)
{
size_t i = 0;
for (size_t loc : var_locs) {
Contributor:

What about the variables after the loop variables? For example, if the inputs of the loop are [loop_var1, loop_var2, other_var1, other_var2], how are other_var1 and other_var2 processed?

Member Author:

In line 745, a dummy position is appended to var_locs, so those trailing variables are covered as well.

max_iterations=10,
)
if hybridize:
model.hybridize(inline_limit=0)
Contributor:

is inline_limit necessary?

Contributor:

i don't think people use inline_limit by default. we should test common cases.

Member Author:

Removed

assert result[0].asscalar() == 1
assert result[1].asscalar() == 0
assert result[2].asscalar() == 0
# Case 2.1: result should be sum([1, 2, 3 ... 100])
Contributor:

the cases are labelled with xx.yy. do xx and yy have special meanings?

Member Author:

In xx.yy, xx identifies a testing template; for instance, def case_1(**params) is a template for various testcases. yy identifies the specific shapes and other conditions I feed into that template.

assert all(outputs.asnumpy() == np.arange(1, 1001).reshape(1000, 1))
assert result_i.asscalar() == 1001
assert result_s.asscalar() == 500500
# Case 2.3: very corner case
Contributor:

what corner case does it test?

Member Author:

The loop body is never executed. Thanks for the review; I changed the comment to "Case 2.3: a corner case, in which the loop body is never executed" to make it clear.
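
A sketch of that corner case with the imperative API (my own example; I leave aside the exact contents of `outputs` in the zero-step case):

import mxnet as mx

cond = lambda i: i < 0                   # already False on the first check
func = lambda i: ([i], [i + 1])          # never executed
outputs, states = mx.nd.contrib.while_loop(
    cond, func, loop_vars=[mx.nd.array([0])], max_iterations=4)
print(states[0].asscalar())              # 0.0 -- the loop variable is untouched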

# It is for the case that inputs & outputs are the same
# There are 3 outputs
# There are 4 states: i, s_0, s_1, s_2
# i is used in both differentiable (take) and non-differentiable (+) occasions
Contributor:

take is differentiable, + is non-differentiable?

Member Author:

My mistake. Fixed :-)

@junrushao changed the title [MXNET-626] [WIP] Add while_loop [MXNET-626] Add while_loop Jul 9, 2018
@junrushao (Member Author):

@szha @piiswrong @eric-haibin-lin Hey could you help take a look at the code?

@junrushao (Member Author) commented Jul 9, 2018

@zheng-da Hey, I addressed all of the comments except the one in WhileLoopComputeExCPU, where I am not sure there is a better solution. Would you mind taking a look?

@zheng-da (Contributor) commented Jul 9, 2018

There is a problem in the current implementation: while_loop doesn't inform users of the number of iterations it actually performs, so users can't tell how much data in the output arrays is valid. One option is to return a mask or a scalar symbol indicating the number of iterations.

@junrushao (Member Author) commented Jul 9, 2018

@zheng-da For users, an additional variable like i among the loop variables can carry the number of steps taken, and it intuitively matches how a Python while-loop works.
For developers, the number of steps taken is recorded in WhileLoopState::n_iterations; I am not sure how to access it from outside. Will this be a problem in future development?
If we consider changing the API, i.e. adding an extra symbol that indicates the number of steps taken, it will require non-trivial discussion of the naming, dtype and ctx of such a Symbol, and all of that may end up hard-coded in the MXNet Python lib. I am not sure there is a better or more concise way to do this.


@@ -191,3 +191,128 @@ def check_input(inputs, in_type, msg):
if not_data_list and len(outputs) == 1:
outputs = outputs[0]
return (outputs, states)


def while_loop(loop_vars, cond, func, max_iterations):
Contributor:

put loop_vars after func?

Member Author:

My bad. Fixed :-)


`func` is a user-defined function as the loop body.
It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
The number of elements, shape, dtype of each element in `step_output` should be consistent.
Contributor:

what does consistent mean?

Member Author (Jul 10, 2018):

I changed this to

In each step, step_output should contain the same number of elements. Through all steps, the i-th element of step_output should have the same shape and dtype. Also, new_loop_vars should contain the same number of elements as loop_vars, and the corresponding element should have the same shape and dtype.

Does this seem better?
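
To make "consistent" concrete, here is a hypothetical func of my own that violates the rule; the imperative implementation rejects it when it tries to stack the per-step outputs (see the snippet reviewed below):

import mxnet as mx

cond = lambda i: i < 3
# The step output's shape changes from step to step, so the i-th elements
# of step_output cannot be stacked along axis 0.
func = lambda i: ([mx.nd.ones((int(i.asscalar()) + 1,))], [i + 1])
try:
    mx.nd.contrib.while_loop(cond, func, loop_vars=[mx.nd.array([0])],
                             max_iterations=3)
except ValueError as err:
    print(err)                # reports that step_outputs are inconsistent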

steps = 0
outputs = []
while steps < max_iterations and \
_to_python_scalar(cond(*loop_vars), bool, "Return value of cond"): # loop condition
Contributor:

So this could end before reaching max_iterations. Isn't this inconsistent with the symbol version?

Member Author:

Yes, they are not consistent, and I put a warning in the docstring. Should I pad the outputs so that they look the same?

Contributor:

i think so.

Member Author (Jul 12, 2018):

@zheng-da So should I pad the arrays to make them consistent?

Contributor:

it's better to do so, in my opinion. what do you think? @piiswrong

Member Author:

Fixed

Contributor:

Yes, ndarray and symbol functions should give the same result for the same input; otherwise hybridize may break.

try:
outputs = list(ndarray.op.stack(*item) for item in zip(*outputs))
except ValueError:
raise ValueError("step_outputs are inconsistent on each step")
Contributor:

Be explicit about which value is inconsistent; print out the inconsistent shapes if possible.

Member Author:

Fixed

@@ -336,3 +336,205 @@ def check_data(inputs, in_type, msg):
states = states[0]

return (outs, states)

def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"):
Contributor:

interface different from ndarray?

Member Author:

My bad. Fixed


`max_iterations` is a scalar that defines the maximum number of iterations allowed.

This function returns a list of NDArrays of length `|step_output| + |loop_vars|`.
Contributor:

The return value has a different format from TF. Any specific reason?

Member Author:

I updated the outdated docstring to the following:

This function returns two lists as a tuple. The first list has the length of |step_output|, in which the i-th element is the stack (along axis 0) of the i-th elements of step_output from all steps. The second list has the length of |loop_vars|, which represents the final states of the loop variables.

Currently we don't have dynamic shape inference, so we cannot support a TF-like dynamically-sized, per-time-step TensorArray. Instead we split the requirement into two parts: 1) a per-time-step array bounded by max_iterations; 2) loop variables, which keep the same shape throughout the loop.

@zheng-da (Contributor):

The interface of this while_loop operator is close to the ONNX definition: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop
The interface of TF's while_loop requires dynamic shape and doesn't allow an efficient implementation. This interface is more flexible: if it returns ([], loop_vars), it's the same as the TF interface.

Contributor:

So does it return two lists separately or concatenated together?

Member Author:

@piiswrong Separately

Contributor:

"This function returns two lists as a tuple" -> "This function returns two lists."

"as a tuple" makes it sound like the two lists are concatenated into a tuple.

Member Author:

@piiswrong fixed
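
Returning to zheng-da's point above that ([], loop_vars) matches the TF interface, a sketch of my own (assuming the merged API accepts an empty step_output):

import mxnet as mx

# No per-step outputs at all: func returns ([], new_loop_vars), so only
# the final loop variables come back, as with TF's while_loop.
cond = lambda i, s: i < 5
func = lambda i, s: ([], [i + 1, s + i])
outputs, states = mx.nd.contrib.while_loop(
    cond, func, loop_vars=[mx.nd.array([0]), mx.nd.array([0])],
    max_iterations=5)
print(outputs)                 # []
print(states[1].asscalar())    # 10.0 == 0 + 1 + 2 + 3 + 4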

`max_iterations` is a scalar that defines the maximum number of iterations allowed.

This function returns a list of Symbols of length `|step_output| + |loop_vars|`.
The i-th element in the first `|step_output|` ones of the list represent
Member:

I don't understand this sentence.

Member Author:

My bad, this is outdated. Does the following seem better?

This function returns two lists as a tuple. The first list has the length of |step_output|, in which the i-th element is the stack (along axis 0) of the i-th elements of step_output from all steps. The second list has the length of |loop_vars|, which represents the final states of the loop variables.

@@ -336,3 +336,205 @@ def check_data(inputs, in_type, msg):
states = states[0]

return (outs, states)

def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"):
"""Run a while loop with user-defined computation and loop condition.
Member:

Is max_iterations always required?

Contributor:

yes, without dynamic shape, a user has to provide max_iterations

>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
>>> outputs, states = mx.nd.contrib.while_loop(loop_vars, cond, func, max_iterations=10)
Member:

Can you show the output results of this example?

Member Author:

The results are

>>> outputs
[
 [[ 1]
  [ 2]
  [ 4]
  [ 7]
  [11]
  [16]]
 <NDArray 6x1 @cpu(0)>]
>>> states
[
 [6]
 <NDArray 1 @cpu(0)>,
 [16]
 <NDArray 1 @cpu(0)>]

Should I put this snippet into docstring?

Contributor:

i think so.

Member Author:

Fixed

@junrushao (Member Author):

@piiswrong Can we merge this PR?

@fhieber (Contributor) commented Jul 14, 2018

@junrushao1994 thanks for working on this operator! I can't wait to have this available in MXNet. I took the liberty of checking out this PR and compiling MXNet on a Mac to test this operator.
I tried using the example described in the docstring:

>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.sym.var('i'), mx.sym.var('s'))
>>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10)

I am evaluating this symbolic graph like this to inspect the outputs from each iteration:

mx.sym.Group(outputs + states).eval(i=mx.nd.zeros((1,)), s=mx.nd.ones((1,)))[0].asnumpy().astype('int32')

As expected the output is padded to max_iterations, but the padding values are non-deterministic when repeatedly calling the above line:

In [65]: mx.sym.Group(outputs + states).eval(i=mx.nd.zeros((1,)), s=mx.nd.ones((1,)))[0].asnumpy().astype('int32')
Out[65]:
array([[          1],
       [          2],
       [          4],
       [          7],
       [         11],
       [         16],
       [-2147483648],
       [          0],
       [-2147483648],
       [          0]], dtype=int32)

In [69]: mx.sym.Group(outputs + states).eval(i=mx.nd.zeros((1,)), s=mx.nd.ones((1,)))[0].asnumpy().astype('int32')
Out[69]:
array([[ 1],
       [ 2],
       [ 4],
       [ 7],
       [11],
       [16],
       [ 0],
       [ 0],
       [ 0],
       [ 0]], dtype=int32)

It seems that the padding values are non-deterministic between runs.

@junrushao (Member Author):

@fhieber Thank you so much for pointing this out. Yes, it is expected behavior, documented here: https://github.com/junrushao1994/mxnet/blob/while_loop_pr/python/mxnet/ndarray/contrib.py#L266

Actually, I discussed this with my mentor @zheng-da several days ago, and he suggested not filling in a deterministic value, to preserve performance. Once dynamic shape inference is available, padding will no longer be needed.

Currently, if you want to know the exact number of steps taken, you can add an `i` to the loop variables and increment it on each step.

Thank you so much for the detailed feedback!

@fhieber (Contributor) commented Jul 14, 2018

Thanks! I actually missed that comment in the documentation, sorry about that. Makes sense to keep the fill value undefined for performance reasons.

I am considering this operator as the core building block for a beam search algorithm capped by max_iterations. I would then, exactly as you suggested, keep a loop variable that counts the actual number of steps taken and use it to slice the results out of the returned, padded output.
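
That pattern might look roughly like this (my own sketch; the stopping test is a hypothetical stand-in for a real beam-search criterion):

import mxnet as mx

cond = lambda i, score: score < 10          # stand-in stopping criterion
func = lambda i, score: ([score], [i + 1, score * 2])
outputs, states = mx.nd.contrib.while_loop(
    cond, func, loop_vars=[mx.nd.array([0]), mx.nd.array([1])],
    max_iterations=20)

n_steps = int(states[0].asscalar())         # number of valid steps taken
valid = outputs[0][:n_steps]                # slice away the undefined padding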

@junrushao mentioned this pull request Jul 17, 2018
@eric-haibin-lin (Member):

lgtm

gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
ctxs = [mx.cpu(0), mx.gpu(0)]
for cell in cells:
for ctx in ctxs:
for batch_size in batch_sizes:
if len(get_gpus()) == 0 and ctx == mx.gpu(0):
continue

if isinstance(cell, gluon.rnn.GRUCell):
if isinstance(cell, gluon.rnn.RNNCell):
Member:

nit: there's quite a bit of repetition in the below code.

Member Author:

I didn't touch this file; it was renamed from https://github.com/apache/incubator-mxnet/blob/master/benchmark/python/control_flow/rnn.py. Should I simplify it in this PR, or in a separate one?

Member:

you can coordinate with @zheng-da

Member Author:

@szha Da and I decided that I will rewrite these two files. Will push a commit later today.

out, states = F.contrib.while_loop(
cond=lambda i, *_: i < self.length,
func=_func,
# lambda i, *s: [i + 1] + list(self.cell(s)),
Member:

remove commented code

Member Author (Jul 18, 2018):

Removed. Sorry for that T_T

num_batches = 20

# Imperative
cell0 = copy.deepcopy(cell)
Member:

nit: there's quite a bit of repetition in this function

The second list has the length of `|loop_vars|`,
which represents final states of loop variables.

Warning 1: for now, the axis 0 of all NDArrays in the first list are `max_iterations`,
Member:

Use markup in sphinx for warning?

.. warning::
   For now, the axis 0 of all NDArrays in the first list are `max_iterations`, due to lack of dynamic shape inference.

Member Author:

Cool! Updated


Returns
-------
outputs: two lists, which both contains 0, 1 or more NDArrays.
Member:

Returns
------
outputs: list of NDArray
    stacked output from each step
states: list of NDArray
    final state

Member Author:

That makes much more sense. Updated. Thank you!

for i_th, items in enumerate(zip(*outputs), 1):
# `mx.ndarray.pad` only support 4-D or 5-D inputs for now
# so we could not use it.
items = [x.reshape([1] + list(x.shape)) for x in items]
Member:

use expand_dims(0)

Member Author:

Fixed.

items = [x.reshape([1] + list(x.shape)) for x in items]
if steps != max_iterations and items:
pad_shape = [max_iterations - steps] + list(items[0].shape[1: ])
pad = ndarray.empty(
Member:

why only pad at the end instead of allocating empty arrays upfront?

Member Author:

Autograd does not support in-place operations, such as writing into an NDArray in place.

Member:

🤦‍♂️
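
For reference, a minimal sketch of the limitation junrushao refers to (assuming 2018-era autograd semantics; the exact error text varies by version):

import mxnet as mx
from mxnet import autograd

x = mx.nd.ones((3,))
x.attach_grad()
with autograd.record():
    y = x * 2
    try:
        y += 1                        # in-place write on a recorded array
    except mx.base.MXNetError as err:
        print("rejected:", err)       # autograd forbids in-place ops here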


@zheng-da mentioned this pull request Jul 18, 2018
which represents final states of loop variables.

.. warning::
For now, the axis 0 of all Symbols in the first list are `max_iterations`,
Member:

There should be an empty line before and after the warning block. Also, indentation is needed for it to work; refer to the sphinx docs.

Member Author (Jul 18, 2018):

@szha Sorry for the mistake. I modified the warning block according to the sphinx example just now. Could you help take a look again?

@junrushao (Member Author):

Just now I rewrote the benchmark code for both foreach and while_loop in a single file.

@szha merged commit 54632bc into apache:master Jul 19, 2018
KellenSunderland pushed a commit to KellenSunderland/incubator-mxnet that referenced this pull request Jul 21, 2018
* Add while_loop

* Avoid input/output overlap for nnvm graph cut

* Add more testcases

* Enhance test 4.2

* Add more complicated testcases; Add testcase for nested loop

* Check unused loop_vars in while_loop

* Add testcases for RNN

* Make lint happy

* Make lint happy

* Address TODOs

* Fix flaky test for while_loop

* Address comments

* Improve docstring

* Improve error message

* Add benchmark code

* Update benchmarks

* Allow sparse types

* Make max_iterations default to None

* Add while_loop to docs/api/python/{symbol|ndarray}/contrib.md

* Pad imperative while_loop so that it has the same shape with the symbolic one

* Add example result into the example section

* Remove unused class member

* Rename unittest to test_contrib_control_flow.py

* Update docstring

* Update docstring

* Trigger CI

* Change threshold for assert_almost_equal

* Trigger CI

* Address comments from szha

* Rewrite benchmark code

* Fix sphinx warning
@junrushao deleted the while_loop_pr branch July 27, 2018 01:14
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018