b9b3c90 to 1e1fdd3
Job PR-233/7 is complete. |
Great! Can you make sure the |
@leezu Hey, thank you for the suggestions! So do you prefer I create a new method Thanks! |
Would it make sense to make
Also, is there any advantage of |
@leezu
Good job. But for translation tasks, the batch_size can change across iterations, so we cannot simply treat it as an extra parameter.
@szhengac It is mandatory for now because static shape inference is required. This issue could be alleviated once symbolic shape is realized. |
Codecov Report
@@           Coverage Diff           @@
##           master    #233   +/-   ##
=======================================
  Coverage   74.69%   74.69%
=======================================
  Files          83       83
  Lines        7682     7682
  Branches     1315     1315
=======================================
  Hits         5738     5738
  Misses       1675     1675
  Partials      269      269
Job PR-233/8 is complete. |
Question: do you guys prefer the name |
vocab_size, which is a term already used elsewhere in gluonnlp. |
@junrushao1994 Yes, you can change it to inherit from Block. Just change the |
@szha Seems that there are already many |
50aaeb2 to 279953e
279953e to 7f00f1b
It seems that numerical instability and uncertainty can cause the unit tests to fail. For example, when there are close (or equal) values,
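One way close or equal values can make such a test brittle is tie-breaking in top-k selection. The snippet below is a pure-Python illustration only: `topk_indices` and its `prefer_last` flag are hypothetical helpers, not MXNet's `topk`, and the scores are made up.

```python
def topk_indices(scores, k, prefer_last=False):
    """Return indices of the k largest scores.

    Ties are broken by index: lowest index first by default,
    highest index first when prefer_last is True.
    """
    order = sorted(
        range(len(scores)),
        key=lambda i: (-scores[i], -i if prefer_last else i),
    )
    return order[:k]


# Two candidates share the best score.
scores = [0.25, 0.5, 0.5, 0.125]
first = topk_indices(scores, k=1)                    # tie resolved toward the lower index
last = topk_indices(scores, k=1, prefer_last=True)   # tie resolved toward the higher index

# Both are valid top-1 answers, yet the selected indices differ.
print(first, last)
```

A test that compares selected indices from two implementations can therefore fail even when both are correct; comparing the selected scores instead is less sensitive to ties.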
@junrushao1994 our test environment depends on a specific nightly version. Check under env/ and see if you need to update the date of the nightly build. |
@zheng-da found the bug, and we have just submitted the fix here: apache/mxnet#12078
8dc047f to 3bd29db
The most recent commit fails the CI test for the following reason:
Is that an incompatibility issue with the MXNet nightly build, cuDNN, or something else? @szha
@junrushao1994 yeah, I'm working on it as part of #264 |
3bd29db to 6dec88d
Job PR-233/27 is complete. |
c83b2fe to 9a4521b
Job PR-233/28 is complete. |
@szha This PR has passed the CI tests. Could you help review the code, especially the docstrings? (My English is pretty bad.) Thank you!
tests/unittest/test_beam_search.py
Outdated
@@ -196,13 +198,13 @@ def hybrid_forward(self, F, inputs, states):
        return log_probs, states


class RNNLayerDecoder(Block):
This can be a hybrid block now
tests/unittest/test_beam_search.py
Outdated
for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
    if sampler_cls is HybridBeamSearchSampler and decoder_fn is RNNLayerDecoder:
        # Hybrid beam search does not work on non-hybridizable object
        continue
no need to skip because RNNLayerDecoder can be a hybrid block.
Here is one thing I could not address: the `samples` are `taken` at each time step, which cannot be expressed in `while_loop`. I have no idea how to deal with this.
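To frame the question above, the contract of a `while_loop`-style operator can be pictured in plain Python. This is only a sketch under assumptions: `while_loop_sketch` is a hypothetical stand-in, not MXNet's operator, and padding unused slots with `None` is a choice made here for illustration. It shows that each step can only emit a fixed-shape output and pass along updated loop variables.

```python
def while_loop_sketch(cond, func, loop_vars, max_iterations):
    """Pure-Python stand-in for a while_loop-style control flow operator.

    cond(*loop_vars) -> bool decides whether another step runs.
    func(*loop_vars) -> (step_output, new_loop_vars) runs one step.
    Returns (outputs, final_loop_vars). outputs always has
    max_iterations slots; None marks steps that never ran.
    """
    outputs = [None] * max_iterations
    steps = 0
    while steps < max_iterations and cond(*loop_vars):
        step_output, loop_vars = func(*loop_vars)
        outputs[steps] = step_output
        steps += 1
    return outputs, loop_vars


# Toy decoding loop: keep emitting the next token id until eos_id is reached.
eos_id = 3
outputs, (last_token,) = while_loop_sketch(
    cond=lambda token: token != eos_id,
    func=lambda token: (token + 1, (token + 1,)),
    loop_vars=(0,),
    max_iterations=5,
)
print(outputs, last_token)
```

Anything that must be gathered per step has to fit into either the step output or the loop variables, which is where per-step indexing of earlier results becomes awkward.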
9a4521b to a13a681
1ae164a to 016acac
Job PR-233/32 is complete. |
Thank you all for the help and suggestions!
* Add symbolic beam search
* Update _BeamSearchStepUpdate
* [WIP] Fix a lot of stuff
  Here is one thing I could not address. The `samples` are `taken` at each time stamp, it could not be expressed in `while_loop`. I totally have no idea how to deal with this.
* [WIP] fix typo
* [WIP] Submit fixes for debugging cut subgraph for Da
* Reduce search length to prevent numeral instability propagates
* Make linter happy
* Rename all vocab_num => vocab_size
* Change RNNLayerDecoder to HybridBlock
* Symbol[0] => Symbol.squeeze(axis=0)
Description
Symbolic beam search is made possible after enabling control flow operators mx.sym.contrib.while_loop (apache/mxnet#11566) and mx.sym.contrib.cond (apache/mxnet#11760). In this PR, we create a class HybridBeamSearchSampler, which could be hybridized to perform beam search.
Checklist
Essentials
Changes
* model.beam_search.HybridBeamSearchSampler
* model.beam_search._expand_to_beam_size to accept Symbols
* _extract_and_flatten_nested_structure and _reconstruct_flattened_structure to flatten and unflatten structure used in decoders
* BeamSearchSampler and HybridBeamSearchSampler to work around failing test cases caused by topk
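The flatten and unflatten idea in the list above can be sketched in plain Python. This is an illustration under assumptions, not the PR's implementation: `flatten_nested` and `rebuild_nested` are hypothetical names, only lists and tuples are handled, and strings stand in for decoder state arrays.

```python
def flatten_nested(structure):
    """Flatten nested lists/tuples into (leaves, layout).

    layout records the container types so the nesting can be rebuilt;
    None marks a single leaf.
    """
    if isinstance(structure, (list, tuple)):
        leaves, child_layouts = [], []
        for item in structure:
            item_leaves, item_layout = flatten_nested(item)
            leaves.extend(item_leaves)
            child_layouts.append(item_layout)
        return leaves, (type(structure), child_layouts)
    return [structure], None


def rebuild_nested(leaves, layout):
    """Inverse of flatten_nested: rebuild the nesting from leaves and layout."""
    def build(node_layout, position):
        if node_layout is None:
            return leaves[position], position + 1
        container_type, child_layouts = node_layout
        items = []
        for child in child_layouts:
            item, position = build(child, position)
            items.append(item)
        return container_type(items), position

    rebuilt, consumed = build(layout, 0)
    if consumed != len(leaves):
        raise ValueError("layout does not match the number of leaves")
    return rebuilt


# Nested decoder states, with strings standing in for arrays.
states = [("h0", "c0"), ["h1", ("c1", "m1")]]
flat, layout = flatten_nested(states)
restored = rebuild_nested(flat, layout)
print(flat)
```

Keeping the layout separate from the leaves means a loop can carry a flat list of values and restore the decoder's expected nesting only when the decoder is called.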
TODO
* model.beam_search.HybridBeamSearchSampler
* vocab_num to vocab_size
Comments
* HybridBeamSearchSampler requires two extra arguments, batch_size and vocab_size, compared with BeamSearchSampler
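One common piece of beam search bookkeeping that depends on `vocab_size` is mapping an index into the flattened candidate scores back to a beam and a word. The snippet below is a pure-Python sketch of that step, not code from this PR; `split_flat_indices` is a hypothetical helper and the scores are invented.

```python
def split_flat_indices(flat_indices, vocab_size):
    """Map indices into a flattened (beam_size * vocab_size) score row
    back to (beam_id, word_id) pairs."""
    return [(idx // vocab_size, idx % vocab_size) for idx in flat_indices]


beam_size, vocab_size = 2, 4
# Candidate log-probabilities for one sentence: one row per beam, one column per word.
scores = [
    [-1.0, -0.25, -3.0, -2.0],   # beam 0
    [-0.5, -4.0, -0.75, -5.0],   # beam 1
]
flat = [score for row in scores for score in row]

# Pick the beam_size best candidates across all beams at once.
top = sorted(range(len(flat)), key=lambda i: -flat[i])[:beam_size]
pairs = split_flat_indices(top, vocab_size)
print(top, pairs)
```

With imperative arrays, `vocab_size` can be read from the score array's shape at run time; a symbolic graph has no concrete shape to inspect while it is being built, which is consistent with the static shape requirement mentioned earlier in the thread.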