From e602ff3f22dc17b64e7280b53556a5540baa6305 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Thu, 26 Jul 2018 18:12:08 -0700
Subject: [PATCH 01/10] Add symbolic beam search

---
 gluonnlp/model/beam_search.py | 161 +++++++++++++++++++++++++++++++++-
 1 file changed, 157 insertions(+), 4 deletions(-)

diff --git a/gluonnlp/model/beam_search.py b/gluonnlp/model/beam_search.py
index 7feab989c5..60083bd3d0 100644
--- a/gluonnlp/model/beam_search.py
+++ b/gluonnlp/model/beam_search.py
@@ -20,7 +20,7 @@
 from __future__ import absolute_import
 from __future__ import print_function
 
-__all__ = ['BeamSearchScorer', 'BeamSearchSampler']
+__all__ = ['BeamSearchScorer', 'BeamSearchSampler', 'HybridizedBeamSearchSampler']
 
 import numpy as np
 import mxnet as mx
@@ -80,7 +80,7 @@ def _expand_to_beam_size(data, beam_size, batch_size, state_info=None):
 
     Parameters
     ----------
-    data : A single NDArray or nested container with NDArrays
+    data : A single NDArray/Symbol or nested container with NDArrays/Symbols
         Each NDArray/Symbol should have shape (N, ...) when state_info is None,
         or same as the layout in state_info when it's not None.
     beam_size : int
@@ -92,8 +92,8 @@ def _expand_to_beam_size(data, beam_size, batch_size, state_info=None):
         When None, this method assumes that the batch axis is the first dimension.
     Returns
     -------
-    new_states : Object that contains NDArrays
-        Each NDArray should have shape batch_size * beam_size on the batch axis.
+    new_states : Object that contains NDArrays/Symbols
+        Each NDArray/Symbol should have shape batch_size * beam_size on the batch axis.
     """
     assert not state_info or isinstance(state_info, (type(data), dict)), \
         'data and state_info doesn\'t match, ' \
@@ -128,6 +128,15 @@ def _expand_to_beam_size(data, beam_size, batch_size, state_info=None):
         return data.expand_dims(batch_axis+1)\
                    .broadcast_axes(axis=batch_axis+1, size=beam_size)\
                    .reshape(new_shape)
+    elif isinstance(data, mx.sym.Symbol):
+        if not state_info:
+            batch_axis = 0
+        else:
+            batch_axis = state_info['__layout__'].find('N')
+        new_shape = (0, ) * batch_axis + (-3, -2)
+        return data.expand_dims(batch_axis+1)\
+                   .broadcast_axes(axis=batch_axis+1, size=beam_size)\
+                   .reshape(new_shape)
     else:
         raise NotImplementedError
 
@@ -384,3 +393,147 @@ def __call__(self, inputs, states):
         return mx.nd.round(samples).astype(np.int32),\
                scores,\
                mx.nd.round(valid_length).astype(np.int32)
+
+
+class HybridizedBeamSearchSampler(HybridBlock):
+    r"""Draw samples from the decoder by beam search.
+
+    Parameters
+    ----------
+    batch_size : int
+        The batch size.
+    beam_size : int
+        The beam size.
+    decoder : callable, should be hybridizable
+        Function of the one-step-ahead decoder, should have the form::
+
+            log_probs, new_states = decoder(step_input, states)
+
+        The log_probs and step_input should follow these rules:
+
+        - step_input has shape (batch_size,),
+        - log_probs has shape (batch_size, V),
+        - states and new_states have the same structure and the leading
+          dimension of the inner NDArrays is the batch dimension.
+    eos_id : int
+        Id of the EOS token. No other elements will be appended to the sample if it reaches eos_id.
+    scorer : BeamSearchScorer, default BeamSearchScorer(alpha=1.0, K=5)
+        The score function used in beam search.
+    max_length : int, default 100
+        The maximum search length.
+    vocab_size : int, default None, meaning `decoder._vocab_size`
+        The vocabulary size
+    """
+    def __init__(self, batch_size, beam_size, decoder, eos_id,
+                 scorer=BeamSearchScorer(alpha=1.0, K=5),
+                 max_length=100,
+                 vocab_size=None):
+        self._batch_size = batch_size
+        self._beam_size = beam_size
+        assert beam_size > 0,\
+            'beam_size must be larger than 0. Received beam_size={}'.format(beam_size)
+        self._decoder = decoder
+        self._eos_id = eos_id
+        assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
+        self._max_length = max_length
+        self._scorer = scorer
+        self._state_info_func = getattr(decoder, 'state_info', lambda _=None: None)
+        self._updater = _BeamSearchStepUpdate(beam_size=beam_size, eos_id=eos_id, scorer=scorer,
+                                              state_info=self._state_info_func())
+        self._updater.hybridize()
+        self._vocab_size = vocab_size or getattr(decoder, '_vocab_size', None)
+        assert self._vocab_size is not None,\
+            'Please provide vocab_size or define decoder._vocab_size'
+        assert not hasattr(decoder, '_vocab_size') or decoder._vocab_size == self._vocab_size, \
+            'Provided vocab_size={} is not equal to decoder._vocab_size={}'\
+            .format(self._vocab_size, decoder._vocab_size)
+
+    def hybrid_forward(self, F, inputs, states):  # pylint: disable=arguments-differ
+        """Sample by beam search.
+
+        Parameters
+        ----------
+        F
+        inputs : NDArray or Symbol
+            The initial input of the decoder. Shape is (batch_size,).
+        states : Object that contains NDArrays or Symbols
+            The initial states of the decoder.
+        Returns
+        -------
+        samples : NDArray or Symbol
+            Samples drawn by beam search. Shape (batch_size, beam_size, length). dtype is int32.
+        scores : NDArray or Symbol
+            Scores of the samples. Shape (batch_size, beam_size). We make sure that scores[i, :]
+            are in descending order.
+        valid_length : NDArray or Symbol
+            The valid length of the samples. Shape (batch_size, beam_size). dtype will be int32.
+        """
+        batch_size = self._batch_size
+        beam_size = self._beam_size
+        vocab_size = self._vocab_size
+        # Tile the states and inputs to have shape (batch_size * beam_size, ...)
+        state_info = self._state_info_func(batch_size)
+        states = _expand_to_beam_size(states, beam_size=beam_size, batch_size=batch_size,
+                                      state_info=state_info)
+        step_input = _expand_to_beam_size(inputs, beam_size=beam_size, batch_size=batch_size)
+        # All beams are initialized to alive
+        # Generated samples are initialized to be the inputs
+        # Except the first beam where the scores are set to be zero, all beams have -inf scores.
+        # Valid length is initialized to be 1
+        beam_alive_mask = F.ones(shape=(batch_size, beam_size))
+        valid_length = F.ones(shape=(batch_size, beam_size))
+        if beam_size == 1:
+            scores = F.zeros(shape=(batch_size, beam_size))
+        else:
+            scores = F.concat(
+                F.zeros(shape=(batch_size, 1)),
+                F.full(shape=(batch_size, beam_size - 1), val=LARGE_NEGATIVE_FLOAT),
+                dim=1,
+            )
+        samples = step_input.reshape((batch_size, beam_size, 1))
+        vocab_num = F.ones([vocab_size])
+        batch_shift = F.arange(0, batch_size * beam_size, beam_size)
+
+        def _loop_cond(_i, _step_input, _states, _samples, _valid_length, _scores, beam_alive_mask):
+            return F.sum(beam_alive_mask) > 0
+
+        def _loop_func(i, step_input, states, samples, valid_length, scores, beam_alive_mask):
+            log_probs, new_states = self._decoder(step_input, states)
+            step = i + 1
+            samples, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
+                self._updater(samples, valid_length, log_probs, scores, step, beam_alive_mask,
+                              new_states, vocab_num, batch_shift)
+            step_input = F.relu(chosen_word_ids).reshape((-1,))
+            return (), (step, step_input, states, samples, valid_length, scores, beam_alive_mask)
+
+        _, step_input, states, samples, valid_length, scores, beam_alive_mask = \
+            F.contrib.while_loop(
+                cond=_loop_cond, func=_loop_func, max_iterations=self._max_length,
+                loop_vars=(
+                    F.zeros(shape=(1, )),  # i
+                    step_input,
+                    states,
+                    samples,
+                    valid_length,
+                    scores,
+                    beam_alive_mask
+                )
+            )
+
+        def _then_func():
+            return samples, valid_length
+
+        def _else_func():
+            final_word = F.where(beam_alive_mask,
+                                 F.full(shape=(batch_size, beam_size),
+                                        val=self._eos_id),
+                                 F.full(shape=(batch_size, beam_size),
+                                        val=-1))
+            new_samples = F.concat(samples, final_word.reshape((0, 0, 1)), dim=2)
+            new_valid_length = valid_length + beam_alive_mask
+            return new_samples, new_valid_length
+
+        samples, scores = F.contrib.cond(F.sum(beam_alive_mask) == 0, _then_func, _else_func)
+        return F.round(samples).astype(np.int32),\
+               scores,\
+               F.round(valid_length).astype(np.int32)

From bab7bc9b586bb78ea042a2586093d8628d458984 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Mon, 30 Jul 2018 15:59:27 -0700
Subject: [PATCH 02/10] Update _BeamSearchStepUpdate

---
 gluonnlp/model/beam_search.py | 76 ++++++++++++++++++++---------------
 1 file changed, 44 insertions(+), 32 deletions(-)

diff --git a/gluonnlp/model/beam_search.py b/gluonnlp/model/beam_search.py
index 60083bd3d0..3c6aa59ad7 100644
--- a/gluonnlp/model/beam_search.py
+++ b/gluonnlp/model/beam_search.py
@@ -20,7 +20,7 @@
 from __future__ import absolute_import
 from __future__ import print_function
 
-__all__ = ['BeamSearchScorer', 'BeamSearchSampler', 'HybridizedBeamSearchSampler']
+__all__ = ['BeamSearchScorer', 'BeamSearchSampler', 'HybridBeamSearchSampler']
 
 import numpy as np
 import mxnet as mx
@@ -192,12 +192,13 @@ def _choose_states(F, states, state_info, indices):
 
 
 class _BeamSearchStepUpdate(HybridBlock):
-    def __init__(self, beam_size, eos_id, scorer, state_info, prefix=None, params=None):
+    def __init__(self, beam_size, eos_id, scorer, state_info, single_step=False, prefix=None, params=None):
         super(_BeamSearchStepUpdate, self).__init__(prefix, params)
         self._beam_size = beam_size
         self._eos_id = eos_id
         self._scorer = scorer
         self._state_info = state_info
+        self._single_step = single_step
         assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
 
     def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask,  # pylint: disable=arguments-differ
@@ -207,8 +208,10 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam
         Parameters
         ----------
         F
-        samples : NDArray or Symbol
-            The current samples generated by beam search. Shape (batch_size, beam_size, L)
+        samples : NDArray or Symbol or an empty list
+            The current samples generated by beam search.
+            An empty list when single_step is True.
+            (batch_size, beam_size, L) when single_step is False.
         valid_length : NDArray or Symbol
             The current valid lengths of the samples
         log_probs : NDArray or Symbol
@@ -230,8 +233,10 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam
 
         Returns
         -------
-        new_samples : NDArray or Symbol
-            The updated samples. Shape (batch_size, beam_size, L + 1)
+        new_samples : NDArray or Symbol or an empty list
+            The updated samples.
+            When single_step is True, it is an empty list.
+            When single_step is False, shape (batch_size, beam_size, L + 1)
         new_valid_length : NDArray or Symbol
             Valid lengths of the samples. Shape (batch_size, beam_size)
         new_scores : NDArray or Symbol
@@ -269,10 +274,14 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam
                                   -F.ones_like(indices),
                                   chosen_word_ids)
         # Update the samples and valid_length
-        new_samples = F.concat(F.take(samples.reshape(shape=(-3, 0)),
-                                      batch_beam_indices.reshape(shape=(-1,))),
-                               chosen_word_ids.reshape(shape=(-1, 1)), dim=1)\
-                       .reshape(shape=(-4, -1, beam_size, 0))
+        if self._single_step:
+            assert isinstance(samples, (tuple, list)) and not samples, "When single_step=True, please give an empty list as `samples`"
+            new_samples = []
+        else:
+            new_samples = F.concat(F.take(samples.reshape(shape=(-3, 0)),
+                                          batch_beam_indices.reshape(shape=(-1,))),
+                                   chosen_word_ids.reshape(shape=(-1, 1)), dim=1)\
+                           .reshape(shape=(-4, -1, beam_size, 0))
         new_valid_length = F.take(valid_length.reshape(shape=(-1,)),
                                   batch_beam_indices.reshape(shape=(-1,))).reshape((-1, beam_size))\
                            + 1 - use_prev
@@ -395,7 +404,7 @@ def __call__(self, inputs, states):
                scores,\
                mx.nd.round(valid_length).astype(np.int32)
 
 
-class HybridizedBeamSearchSampler(HybridBlock):
+class HybridBeamSearchSampler(HybridBlock):
     r"""Draw samples from the decoder by beam search.
 
     Parameters
     ----------
@@ -404,7 +413,7 @@ class HybridizedBeamSearchSampler(HybridBlock):
     batch_size : int
         The batch size.
     beam_size : int
         The beam size.
-    decoder : callable, should be hybridizable
+    decoder : callable, must be hybridizable
         Function of the one-step-ahead decoder, should have the form::
 
             log_probs, new_states = decoder(step_input, states)
 
         The log_probs and step_input should follow these rules:
 
         - step_input has shape (batch_size,),
         - log_probs has shape (batch_size, V),
         - states and new_states have the same structure and the leading
           dimension of the inner NDArrays is the batch dimension.
     eos_id : int
         Id of the EOS token. No other elements will be appended to the sample if it reaches eos_id.
-    scorer : BeamSearchScorer, default BeamSearchScorer(alpha=1.0, K=5)
+    scorer : BeamSearchScorer, default BeamSearchScorer(alpha=1.0, K=5), must be hybridizable
         The score function used in beam search.
     max_length : int, default 100
         The maximum search length.
@@ -426,8 +435,9 @@ class HybridizedBeamSearchSampler(HybridBlock):
     """
     def __init__(self, batch_size, beam_size, decoder, eos_id,
                  scorer=BeamSearchScorer(alpha=1.0, K=5),
-                 max_length=100,
-                 vocab_size=None):
+                 max_length=100, vocab_size=None,
+                 prefix=None, params=None):
+        super(HybridBeamSearchSampler, self).__init__(prefix, params)
         self._batch_size = batch_size
         self._beam_size = beam_size
         assert beam_size > 0,\
@@ -439,7 +449,7 @@ def __init__(self, batch_size, beam_size, decoder, eos_id,
         self._max_length = max_length
         self._scorer = scorer
         self._state_info_func = getattr(decoder, 'state_info', lambda _=None: None)
         self._updater = _BeamSearchStepUpdate(beam_size=beam_size, eos_id=eos_id, scorer=scorer,
-                                              state_info=self._state_info_func())
+                                              single_step=True, state_info=self._state_info_func())
         self._updater.hybridize()
         self._vocab_size = vocab_size or getattr(decoder, '_vocab_size', None)
         assert self._vocab_size is not None,\
@@ -473,14 +483,9 @@ def hybrid_forward(self, F, inputs, states):  # pylint: disable=arguments-diffe
         vocab_size = self._vocab_size
         # Tile the states and inputs to have shape (batch_size * beam_size, ...)
         state_info = self._state_info_func(batch_size)
+        step_input = _expand_to_beam_size(inputs, beam_size=beam_size, batch_size=batch_size)
         states = _expand_to_beam_size(states, beam_size=beam_size, batch_size=batch_size,
                                       state_info=state_info)
-        step_input = _expand_to_beam_size(inputs, beam_size=beam_size, batch_size=batch_size)
-        # All beams are initialized to alive
-        # Generated samples are initialized to be the inputs
-        # Except the first beam where the scores are set to be zero, all beams have -inf scores.
-        # Valid length is initialized to be 1
-        beam_alive_mask = F.ones(shape=(batch_size, beam_size))
         valid_length = F.ones(shape=(batch_size, beam_size))
         if beam_size == 1:
             scores = F.zeros(shape=(batch_size, beam_size))
@@ -490,30 +495,29 @@ def hybrid_forward(self, F, inputs, states):  # pylint: disable=arguments-diffe
                 F.full(shape=(batch_size, beam_size - 1), val=LARGE_NEGATIVE_FLOAT),
                 dim=1,
             )
-        samples = step_input.reshape((batch_size, beam_size, 1))
+        beam_alive_mask = F.ones(shape=(batch_size, beam_size))
         vocab_num = F.ones([vocab_size])
         batch_shift = F.arange(0, batch_size * beam_size, beam_size)
 
-        def _loop_cond(_i, _step_input, _states, _samples, _valid_length, _scores, beam_alive_mask):
+        def _loop_cond(_i, _step_input, _states, _valid_length, _scores, beam_alive_mask):
            return F.sum(beam_alive_mask) > 0
 
-        def _loop_func(i, step_input, states, samples, valid_length, scores, beam_alive_mask):
+        def _loop_func(i, step_input, states, valid_length, scores, beam_alive_mask):
             log_probs, new_states = self._decoder(step_input, states)
             step = i + 1
-            samples, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
-                self._updater(samples, valid_length, log_probs, scores, step, beam_alive_mask,
+            _, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
+                self._updater([], valid_length, log_probs, scores, step, beam_alive_mask,
                               new_states, vocab_num, batch_shift)
             step_input = F.relu(chosen_word_ids).reshape((-1,))
-            return (), (step, step_input, states, samples, valid_length, scores, beam_alive_mask)
+            return (chosen_word_ids, ), (step, step_input, states, valid_length, scores, beam_alive_mask)
 
-        _, step_input, states, samples, valid_length, scores, beam_alive_mask = \
+        (samples, ), (_, _, _, valid_length, scores, beam_alive_mask) = \
             F.contrib.while_loop(
                 cond=_loop_cond, func=_loop_func, max_iterations=self._max_length,
                 loop_vars=(
                     F.zeros(shape=(1, )),  # i
                     step_input,
                     states,
-                    samples,
                     valid_length,
                     scores,
                     beam_alive_mask
                 )
             )
@@ -521,7 +525,11 @@ def _loop_func(i, step_input, states, samples, valid_length, scores, beam_alive_
 
         def _then_func():
-            return samples, valid_length
+            new_samples = F.concat(
+                step_input.reshape((batch_size, beam_size, 1)),
+                samples,
+                dim=2)
+            return new_samples, valid_length
 
         def _else_func():
             final_word = F.where(beam_alive_mask,
@@ -529,7 +537,11 @@ def _else_func():
                                  F.full(shape=(batch_size, beam_size),
                                         val=self._eos_id),
                                  F.full(shape=(batch_size, beam_size),
                                         val=-1))
-            new_samples = F.concat(samples, final_word.reshape((0, 0, 1)), dim=2)
+            new_samples = F.concat(
+                step_input.reshape((batch_size, beam_size, 1)),
+                samples,
+                final_word.reshape((0, 0, 1)),
+                dim=2)
             new_valid_length = valid_length + beam_alive_mask
             return new_samples, new_valid_length
 

From a6b544728a91cd1ef8b76217b3ca9c3324604188 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Tue, 31 Jul 2018 01:04:59 -0700
Subject: [PATCH 03/10] [WIP] Fix a lot of stuff

Here is one thing I could not address: the `samples` are `taken` at each
time step, which could not be expressed in `while_loop`. I totally have no
idea how to deal with this.
---
 gluonnlp/model/beam_search.py      | 65 +++++++++++++++---------------
 tests/unittest/test_beam_search.py | 20 +++++++--
 2 files changed, 48 insertions(+), 37 deletions(-)

diff --git a/gluonnlp/model/beam_search.py b/gluonnlp/model/beam_search.py
index 3c6aa59ad7..f9833fdb9c 100644
--- a/gluonnlp/model/beam_search.py
+++ b/gluonnlp/model/beam_search.py
@@ -201,17 +201,13 @@ def __init__(self, beam_size, eos_id, scorer, state_info, single_step=False, pre
         self._single_step = single_step
         assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
 
-    def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask,  # pylint: disable=arguments-differ
-                       states, vocab_num, batch_shift):
+    def hybrid_forward(self, F, valid_length, log_probs, scores, step, beam_alive_mask,  # pylint: disable=arguments-differ
+                       states, vocab_num, batch_shift, samples):
         """
 
         Parameters
         ----------
         F
-        samples : NDArray or Symbol or an empty list
-            The current samples generated by beam search.
-            An empty list when single_step is True.
-            (batch_size, beam_size, L) when single_step is False.
         valid_length : NDArray or Symbol
             The current valid lengths of the samples
         log_probs : NDArray or Symbol
@@ -230,10 +226,14 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam
         batch_shift : NDArray or Symbol
             Contains [0, beam_size, 2 * beam_size, ..., (batch_size - 1) * beam_size].
             Shape (batch_size,)
+        samples : NDArray or Symbol or an empty list
+            The current samples generated by beam search.
+            An empty list when single_step is True.
+            (batch_size, beam_size, L) when single_step is False.
 
         Returns
         -------
@@ -385,8 +385,8 @@ def __call__(self, inputs, states):
             batch_shift_nd = mx.nd.arange(0, batch_size * beam_size, beam_size, ctx=ctx)
             step_nd = mx.nd.array([i + 1], ctx=ctx)
             samples, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
-                self._updater(samples, valid_length, log_probs, scores, step_nd, beam_alive_mask,
-                              new_states, vocab_num_nd, batch_shift_nd)
+                self._updater(valid_length, log_probs, scores, step_nd, beam_alive_mask,
+                              new_states, vocab_num_nd, batch_shift_nd, samples)
             step_input = mx.nd.relu(chosen_word_ids).reshape((-1,))
             if mx.nd.sum(beam_alive_mask).asscalar() == 0:
                 return mx.nd.round(samples).astype(np.int32),\
@@ -430,7 +430,7 @@ class HybridBeamSearchSampler(HybridBlock):
         The score function used in beam search.
     max_length : int, default 100
         The maximum search length.
-    vocab_size : int, default None, meaning `decoder._vocab_size`
+    vocab_num : int, default None, meaning `decoder._vocab_size`
         The vocabulary size
     """
     def __init__(self, batch_size, beam_size, decoder, eos_id,
@@ -496,56 +496,55 @@ def hybrid_forward(self, F, inputs, states):  # pylint: disable=arguments-diffe
                 dim=1,
             )
         beam_alive_mask = F.ones(shape=(batch_size, beam_size))
-        vocab_num = F.ones([vocab_size])
+        vocab_num = F.full(shape=(1, ), val=vocab_size)
         batch_shift = F.arange(0, batch_size * beam_size, beam_size)
 
-        def _loop_cond(_i, _step_input, _states, _valid_length, _scores, beam_alive_mask):
+        def _loop_cond(_i, _step_input, _valid_length, _scores, beam_alive_mask, *_states):
             return F.sum(beam_alive_mask) > 0
 
-        def _loop_func(i, step_input, states, valid_length, scores, beam_alive_mask):
+        def _loop_func(i, step_input, valid_length, scores, beam_alive_mask, *states):
             log_probs, new_states = self._decoder(step_input, states)
             step = i + 1
-            _, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
-                self._updater([], valid_length, log_probs, scores, step, beam_alive_mask,
-                              new_states, vocab_num, batch_shift)
-            step_input = F.relu(chosen_word_ids).reshape((-1,))
-            return (chosen_word_ids, ), (step, step_input, states, valid_length, scores, beam_alive_mask)
+            _, new_valid_length, new_scores, chosen_word_ids, new_beam_alive_mask, new_new_states = \
+                self._updater(valid_length, log_probs, scores, step, beam_alive_mask,
+                              new_states, vocab_num, batch_shift, [])
+            new_step_input = F.relu(chosen_word_ids).reshape((-1,))
+            return (chosen_word_ids, ), (step, new_step_input, new_valid_length, new_scores, new_beam_alive_mask) + tuple(new_new_states)
 
-        (samples, ), (_, _, _, valid_length, scores, beam_alive_mask) = \
+        (samples, ), (_, _, new_valid_length, new_scores, new_beam_alive_mask, _) = \
             F.contrib.while_loop(
                 cond=_loop_cond, func=_loop_func, max_iterations=self._max_length,
                 loop_vars=(
                     F.zeros(shape=(1, )),  # i
                     step_input,
-                    states,
                     valid_length,
                     scores,
                     beam_alive_mask
-                )
+                ) + tuple(states)
             )
 
         def _then_func():
             new_samples = F.concat(
                 step_input.reshape((batch_size, beam_size, 1)),
-                samples,
+                samples.transpose((1, 2, 0)),
+                F.full(shape=(batch_size, beam_size, 1), val=-1),
                 dim=2)
-            return new_samples, valid_length
+            new_new_valid_length = new_valid_length
+            return new_samples, new_new_valid_length
 
         def _else_func():
             final_word = F.where(beam_alive_mask,
-                                 F.full(shape=(batch_size, beam_size),
-                                        val=self._eos_id),
-                                 F.full(shape=(batch_size, beam_size),
-                                        val=-1))
+                                 F.full(shape=(batch_size, beam_size), val=self._eos_id),
+                                 F.full(shape=(batch_size, beam_size), val=-1))
             new_samples = F.concat(
                 step_input.reshape((batch_size, beam_size, 1)),
-                samples,
+                samples.transpose((1, 2, 0)),
                 final_word.reshape((0, 0, 1)),
                 dim=2)
-            new_valid_length = valid_length + beam_alive_mask
-            return new_samples, new_valid_length
+            new_new_valid_length = new_valid_length + new_beam_alive_mask
+            return new_samples, new_new_valid_length
 
-        samples, scores = F.contrib.cond(F.sum(beam_alive_mask) == 0, _then_func, _else_func)
-        return F.round(samples).astype(np.int32),\
-               scores,\
-               F.round(valid_length).astype(np.int32)
+        new_samples, new_new_valid_length = F.contrib.cond(F.sum(new_beam_alive_mask) == 0, _then_func, _else_func)
+        return F.round(new_samples).astype(np.int32),\
+               new_scores,\
+               F.round(new_new_valid_length).astype(np.int32)
diff --git a/tests/unittest/test_beam_search.py b/tests/unittest/test_beam_search.py
index c988ae0a9e..1f3c0b38d4 100644
--- a/tests/unittest/test_beam_search.py
+++ b/tests/unittest/test_beam_search.py
@@ -7,7 +7,7 @@
 from mxnet.gluon.rnn import RNNCell, RNN
 from numpy.testing import assert_allclose
 
-from gluonnlp.model import BeamSearchSampler, BeamSearchScorer
+from gluonnlp.model import BeamSearchSampler, BeamSearchScorer, HybridBeamSearchSampler
 
 
 def test_beam_search_score():
@@ -27,7 +27,9 @@ def test_beam_search_score():
 
 
 @pytest.mark.seed(1)
-def test_beam_search():
+@pytest.mark.parametrize('hybridize', [False, True])
+@pytest.mark.parametrize('sampler_cls', [HybridBeamSearchSampler, BeamSearchSampler])
+def test_beam_search(hybridize, sampler_cls):
     def _get_new_states(states, state_info, sel_beam_ids):
         assert not state_info or isinstance(state_info, (type(states), dict)), \
             'states and state_info don\'t match'
@@ -229,12 +231,22 @@ def forward(self, inputs, states):
             state_info = decoder.state_info()
         else:
             state_info = None
+        if sampler_cls is HybridBeamSearchSampler and decoder_fn is RNNLayerDecoder:
+            # Hybrid beam search does not work on non-hybridizable object
+            continue
         for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
             scorer = BeamSearchScorer(alpha=alpha, K=K)
             for max_length in [10, 20]:
-                sampler = BeamSearchSampler(beam_size=beam_size, decoder=decoder, eos_id=eos_id,
-                                            scorer=scorer, max_length=max_length)
                 for batch_size in [1, 2, 5]:
+                    if sampler_cls is HybridBeamSearchSampler:
+                        sampler = sampler_cls(batch_size=batch_size, beam_size=beam_size,
+                                              decoder=decoder, eos_id=eos_id, vocab_size=vocab_num,
+                                              scorer=scorer, max_length=max_length)
+                        if hybridize:
+                            sampler.hybridize()
+                    else:
+                        sampler = sampler_cls(beam_size=beam_size, decoder=decoder, eos_id=eos_id,
+                                              scorer=scorer, max_length=max_length)
                     print(type(decoder).__name__, beam_size, bos_id, eos_id, alpha, K, batch_size)
                     states = decoder.begin_state(batch_size)
                     inputs = mx.nd.full(shape=(batch_size,), val=bos_id)

From 4a27daaa11ea1f00325c747adc11e84c2aebff3f Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Wed, 1 Aug 2018 15:18:22 -0700
Subject: [PATCH 04/10] [WIP] fix typo

---
 gluonnlp/model/beam_search.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gluonnlp/model/beam_search.py b/gluonnlp/model/beam_search.py
index f9833fdb9c..544359b7eb 100644
--- a/gluonnlp/model/beam_search.py
+++ b/gluonnlp/model/beam_search.py
@@ -533,7 +533,7 @@ def _then_func():
             return new_samples, new_new_valid_length
 
         def _else_func():
-            final_word = F.where(beam_alive_mask,
+            final_word = F.where(new_beam_alive_mask,
                                  F.full(shape=(batch_size, beam_size), val=self._eos_id),
                                  F.full(shape=(batch_size, beam_size), val=-1))
             new_samples = F.concat(

From 54b73401c9495a05d9ef4df726a345d15f3ba75a Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Fri, 3 Aug 2018 21:18:44 -0700
Subject: [PATCH 05/10] [WIP] Submit fixes for debugging cut subgraph for Da

---
 gluonnlp/model/beam_search.py      | 137 +++++++++++++++++++++--------
 tests/unittest/test_beam_search.py |   2 +-
 2 files changed, 103 insertions(+), 36 deletions(-)

diff --git a/gluonnlp/model/beam_search.py b/gluonnlp/model/beam_search.py
index 544359b7eb..740c1778d0 100644
--- a/gluonnlp/model/beam_search.py
+++ b/gluonnlp/model/beam_search.py
@@ -75,6 +75,66 @@ def hybrid_forward(self, F, log_probs, scores, step):  # pylint: disable=argume
         return candidate_scores
 
 
+def _extract_and_flatten_nested_structure(data, flattened=None):
+    """Flatten the structure of a nested container to a list.
+
+    Parameters
+    ----------
+    data : A single NDArray/Symbol or nested container with NDArrays/Symbols.
+        The nested container to be flattened.
+    flattened : list or None
+        The container that holds the flattened result.
+    Returns
+    -------
+    structure : An integer or a nested container with integers.
+        The extracted structure of the container of `data`.
+    flattened : (optional) list
+        The container that holds the flattened result.
+        It is returned only when the input argument `flattened` is not given.
+    """
+    if flattened is None:
+        flattened = []
+        structure = _extract_and_flatten_nested_structure(data, flattened)
+        return structure, flattened
+    if isinstance(data, list):
+        return list(_extract_and_flatten_nested_structure(x, flattened) for x in data)
+    elif isinstance(data, tuple):
+        return tuple(_extract_and_flatten_nested_structure(x, flattened) for x in data)
+    elif isinstance(data, dict):
+        return {k: _extract_and_flatten_nested_structure(v, flattened) for k, v in data.items()}
+    elif isinstance(data, (mx.sym.Symbol, mx.nd.NDArray)):
+        flattened.append(data)
+        return len(flattened) - 1
+    else:
+        raise NotImplementedError
+
+
+def _reconstruct_flattened_structure(structure, flattened):
+    """Reconstruct the flattened list back to (possibly) nested structure.
+
+    Parameters
+    ----------
+    structure : An integer or a nested container with integers.
+        The extracted structure of the container of `data`.
+    flattened : list or None
+        The container that holds the flattened result.
+    Returns
+    -------
+    data : A single NDArray/Symbol or nested container with NDArrays/Symbols.
+        The nested container that was flattened.
+    """
+    if isinstance(structure, list):
+        return list(_reconstruct_flattened_structure(x, flattened) for x in structure)
+    elif isinstance(structure, tuple):
+        return tuple(_reconstruct_flattened_structure(x, flattened) for x in structure)
+    elif isinstance(structure, dict):
+        return {k: _reconstruct_flattened_structure(v, flattened) for k, v in structure.items()}
+    elif isinstance(structure, int):
+        return flattened[structure]
+    else:
+        raise NotImplementedError
+
+
 def _expand_to_beam_size(data, beam_size, batch_size, state_info=None):
     """Tile all the states to have batch_size * beam_size on the batch axis.
 
@@ -201,13 +261,17 @@ def __init__(self, beam_size, eos_id, scorer, state_info, single_step=False, pre
         self._single_step = single_step
         assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
 
-    def hybrid_forward(self, F, valid_length, log_probs, scores, step, beam_alive_mask,  # pylint: disable=arguments-differ
-                       states, vocab_num, batch_shift, samples):
+    def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask,  # pylint: disable=arguments-differ
+                       states, vocab_num, batch_shift):
         """
 
         Parameters
         ----------
         F
+        samples : NDArray or Symbol
+            The current samples generated by beam search.
+            When single_step is True, (batch_size, beam_size, max_length).
+            When single_step is False, (batch_size, beam_size, L).
         valid_length : NDArray or Symbol
             The current valid lengths of the samples
         log_probs : NDArray or Symbol
@@ -226,10 +290,6 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam
         batch_shift : NDArray or Symbol
             Contains [0, beam_size, 2 * beam_size, ..., (batch_size - 1) * beam_size].
             Shape (batch_size,)
-        samples : NDArray or Symbol or an empty list
-            The current samples generated by beam search.
-            An empty list when single_step is True.
-            (batch_size, beam_size, L) when single_step is False.
 
         Returns
         -------
@@ -274,14 +334,12 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam
                                   -F.ones_like(indices),
                                   chosen_word_ids)
         # Update the samples and valid_length
+        new_samples = F.concat(F.take(samples.reshape(shape=(-3, 0)),
+                                      batch_beam_indices.reshape(shape=(-1,))),
+                               chosen_word_ids.reshape(shape=(-1, 1)), dim=1)\
+                       .reshape(shape=(-4, -1, beam_size, 0))
         if self._single_step:
-            assert isinstance(samples, (tuple, list)) and not samples, "When single_step=True, please give an empty list as `samples`"
-            new_samples = []
-        else:
-            new_samples = F.concat(F.take(samples.reshape(shape=(-3, 0)),
-                                          batch_beam_indices.reshape(shape=(-1,))),
-                                   chosen_word_ids.reshape(shape=(-1, 1)), dim=1)\
-                           .reshape(shape=(-4, -1, beam_size, 0))
+            new_samples = new_samples.slice_axis(axis=2, begin=1, end=None)
         new_valid_length = F.take(valid_length.reshape(shape=(-1,)),
                                   batch_beam_indices.reshape(shape=(-1,))).reshape((-1, beam_size))\
                            + 1 - use_prev
@@ -385,8 +443,8 @@ def __call__(self, inputs, states):
             batch_shift_nd = mx.nd.arange(0, batch_size * beam_size, beam_size, ctx=ctx)
             step_nd = mx.nd.array([i + 1], ctx=ctx)
             samples, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
-                self._updater(valid_length, log_probs, scores, step_nd, beam_alive_mask,
-                              new_states, vocab_num_nd, batch_shift_nd, samples)
+                self._updater(samples, valid_length, log_probs, scores, step_nd, beam_alive_mask,
+                              new_states, vocab_num_nd, batch_shift_nd)
             step_input = mx.nd.relu(chosen_word_ids).reshape((-1,))
             if mx.nd.sum(beam_alive_mask).asscalar() == 0:
                 return mx.nd.round(samples).astype(np.int32),\
@@ -486,47 +544,56 @@ def hybrid_forward(self, F, inputs, states):  # pylint: disable=arguments-diffe
         step_input = _expand_to_beam_size(inputs, beam_size=beam_size, batch_size=batch_size)
         states = _expand_to_beam_size(states, beam_size=beam_size, batch_size=batch_size,
                                       state_info=state_info)
-        valid_length = F.ones(shape=(batch_size, beam_size))
+        state_structure, states = _extract_and_flatten_nested_structure(states)
         if beam_size == 1:
-            scores = F.zeros(shape=(batch_size, beam_size))
+            init_scores = F.zeros(shape=(batch_size, 1))
         else:
-            scores = F.concat(
+            init_scores = F.concat(
                 F.zeros(shape=(batch_size, 1)),
                 F.full(shape=(batch_size, beam_size - 1), val=LARGE_NEGATIVE_FLOAT),
                 dim=1,
             )
-        beam_alive_mask = F.ones(shape=(batch_size, beam_size))
         vocab_num = F.full(shape=(1, ), val=vocab_size)
         batch_shift = F.arange(0, batch_size * beam_size, beam_size)
 
-        def _loop_cond(_i, _step_input, _valid_length, _scores, beam_alive_mask, *_states):
+        def _loop_cond(i, _samples, _indices, _step_input, _valid_length, _scores, beam_alive_mask, *_states):
             return F.sum(beam_alive_mask) > 0
 
-        def _loop_func(i, step_input, valid_length, scores, beam_alive_mask, *states):
-            log_probs, new_states = self._decoder(step_input, states)
+        def _loop_func(i, samples, indices, step_input, valid_length, scores, beam_alive_mask, *states):
+            log_probs, new_states = self._decoder(step_input, _reconstruct_flattened_structure(state_structure, states))
             step = i + 1
-            _, new_valid_length, new_scores, chosen_word_ids, new_beam_alive_mask, new_new_states = \
-                self._updater(valid_length, log_probs, scores, step, beam_alive_mask,
-                              new_states, vocab_num, batch_shift, [])
+            new_samples, new_valid_length, new_scores, chosen_word_ids, new_beam_alive_mask, new_new_states = \
+                self._updater(samples, valid_length, log_probs, scores, step, beam_alive_mask,
+                              _extract_and_flatten_nested_structure(new_states)[-1],
+                              vocab_num, batch_shift)
             new_step_input = F.relu(chosen_word_ids).reshape((-1,))
-            return (chosen_word_ids, ), (step, new_step_input, new_valid_length, new_scores, new_beam_alive_mask) + tuple(new_new_states)
+            # We are doing `new_indices = indices[1 : ] + indices[ : 1]`
+            new_indices = F.concat(
+                indices.slice_axis(axis=0, begin=1, end=None),
+                indices.slice_axis(axis=0, begin=0, end=1),
+                dim=0,
+            )
+            return [], (step, new_samples, new_indices, new_step_input, new_valid_length, new_scores, new_beam_alive_mask) + tuple(new_new_states)
 
-        (samples, ), (_, _, new_valid_length, new_scores, new_beam_alive_mask, _) = \
+        _, pad_samples, indices, _, new_valid_length, new_scores, new_beam_alive_mask = \
             F.contrib.while_loop(
                 cond=_loop_cond, func=_loop_func, max_iterations=self._max_length,
                 loop_vars=(
-                    F.zeros(shape=(1, )),  # i
-                    step_input,
-                    valid_length,
-                    scores,
-                    beam_alive_mask
+                    F.zeros(shape=(1, )),                                       # i
+                    F.zeros(shape=(batch_size, beam_size, self._max_length)),  # samples
+                    F.arange(start=0, stop=self._max_length),                   # indices
+                    step_input,                                                 # step_input
+                    F.ones(shape=(batch_size, beam_size)),                      # valid_length
+                    init_scores,                                                # scores
+                    F.ones(shape=(batch_size, beam_size)),                      # beam_alive_mask
                 ) + tuple(states)
-            )
+            )[1][:7]  # I hate Python 2
+        samples = pad_samples.take(indices, axis=2)
 
         def _then_func():
             new_samples = F.concat(
                 step_input.reshape((batch_size, beam_size, 1)),
-                samples.transpose((1, 2, 0)),
+                samples,
                 F.full(shape=(batch_size, beam_size, 1), val=-1),
                 dim=2)
             new_new_valid_length = new_valid_length
             return new_samples, new_new_valid_length
 
         def _else_func():
             final_word = F.where(new_beam_alive_mask,
                                  F.full(shape=(batch_size, beam_size), val=self._eos_id),
                                  F.full(shape=(batch_size, beam_size), val=-1))
             new_samples = F.concat(
                 step_input.reshape((batch_size, beam_size, 1)),
-                samples.transpose((1, 2, 0)),
+                samples,
                 final_word.reshape((0, 0, 1)),
                 dim=2)
             new_new_valid_length = new_valid_length + new_beam_alive_mask
             return new_samples, new_new_valid_length
diff --git a/tests/unittest/test_beam_search.py b/tests/unittest/test_beam_search.py
index 1f3c0b38d4..103149ef78 100644
--- a/tests/unittest/test_beam_search.py
+++ b/tests/unittest/test_beam_search.py
@@ -234,7 +234,7 @@ def forward(self, inputs, states):
         if sampler_cls is HybridBeamSearchSampler and decoder_fn is RNNLayerDecoder:
             # Hybrid beam search does not work on non-hybridizable object
             continue
-        for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
+        for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
             scorer = BeamSearchScorer(alpha=alpha, K=K)
             for max_length in [10, 20]:
                 for batch_size in [1, 2, 5]:

From 09850de249e6314eb6f1a06c701688400e0ac3f3 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Mon, 6 Aug 2018 17:27:02 -0700
Subject: [PATCH 06/10] Reduce search length to prevent numerical instability
 from propagating

---
 tests/unittest/test_beam_search.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unittest/test_beam_search.py b/tests/unittest/test_beam_search.py
index 103149ef78..70cc1ddd1b 100644
--- a/tests/unittest/test_beam_search.py
+++ b/tests/unittest/test_beam_search.py
@@ -236,7 +236,7 @@ def forward(self, inputs, states):
             continue
         for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
             scorer = BeamSearchScorer(alpha=alpha, K=K)
-            for max_length in [10, 20]:
+            for max_length in [3, 5]:
                 for batch_size in [1, 2, 5]:
                     if sampler_cls is HybridBeamSearchSampler:
                         sampler = sampler_cls(batch_size=batch_size, beam_size=beam_size,

From bb703fcbd815a4204c622607e8a23818fd8f10e0 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Mon, 6 Aug 2018 17:46:24 -0700
Subject: [PATCH 07/10] Make linter happy

---
 gluonnlp/model/beam_search.py      | 23 +++++++++++++--------
 tests/unittest/test_beam_search.py | 15 +++++++++------
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/gluonnlp/model/beam_search.py b/gluonnlp/model/beam_search.py
index 740c1778d0..1edbfed370 100644
--- a/gluonnlp/model/beam_search.py
+++ b/gluonnlp/model/beam_search.py
@@ -252,7 +252,8 @@ def _choose_states(F, states, state_info, indices):
 
 
 class _BeamSearchStepUpdate(HybridBlock):
-    def __init__(self, beam_size, eos_id, scorer, state_info, single_step=False, prefix=None, params=None):
+    def __init__(self, beam_size, eos_id, scorer, state_info, single_step=False, \
+                 prefix=None, params=None):
         super(_BeamSearchStepUpdate, self).__init__(prefix, params)
         self._beam_size = beam_size
         self._eos_id = eos_id
@@ -261,7 +262,7 @@ def __init__(self, beam_size, eos_id, scorer, state_info, single_step=False, \
         self._single_step = single_step
         assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
 
-    def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask, # pylint: disable=arguments-differ
+    def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask,  # pylint: disable=arguments-differ
                        states, vocab_num, batch_shift):
         """
 
@@ -556,13 +557,17 @@ def hybrid_forward(self, F, inputs, states):  # pylint: disable=arguments-diffe
         vocab_num = F.full(shape=(1, ), val=vocab_size)
         batch_shift = F.arange(0, batch_size * beam_size, beam_size)
 
-        def _loop_cond(i, _samples, _indices, _step_input, _valid_length, _scores, beam_alive_mask, *_states):
+        def _loop_cond(_i, _samples, _indices, _step_input, _valid_length, _scores, \
+                       beam_alive_mask, *_states):
             return F.sum(beam_alive_mask) > 0
 
-        def _loop_func(i, samples, indices, step_input, valid_length, scores, beam_alive_mask, *states):
-            log_probs, new_states = self._decoder(step_input, _reconstruct_flattened_structure(state_structure, states))
+        def _loop_func(i, samples, indices, step_input, valid_length, scores, \
+                       beam_alive_mask, *states):
+            log_probs, new_states = self._decoder(
+                step_input, _reconstruct_flattened_structure(state_structure, states))
             step = i + 1
-            new_samples, new_valid_length, new_scores, chosen_word_ids, new_beam_alive_mask, new_new_states = \
+            new_samples, new_valid_length, new_scores, \
+                chosen_word_ids, new_beam_alive_mask, new_new_states = \
                 self._updater(samples, valid_length, log_probs, scores, step, beam_alive_mask,
                               _extract_and_flatten_nested_structure(new_states)[-1],
                               vocab_num, batch_shift)
@@ -573,7 +578,8 @@ def _loop_func(i, samples, indices, step_input, valid_length, scores, \
                 indices.slice_axis(axis=0, begin=0, end=1),
                 dim=0,
             )
-            return [], (step, new_samples, new_indices, new_step_input, new_valid_length, new_scores, new_beam_alive_mask) + tuple(new_new_states)
+            return [], (step, new_samples, new_indices, new_step_input, new_valid_length, \
+                        new_scores, new_beam_alive_mask) + tuple(new_new_states)
 
         _, pad_samples, indices, _, new_valid_length, new_scores, new_beam_alive_mask = \
             F.contrib.while_loop(
                 cond=_loop_cond, func=_loop_func, max_iterations=self._max_length,
@@ -611,7 +617,8 @@ def _else_func():
             new_new_valid_length = new_valid_length + new_beam_alive_mask
             return new_samples, new_new_valid_length
 
-        new_samples, new_new_valid_length = F.contrib.cond(F.sum(new_beam_alive_mask) == 0, _then_func, _else_func)
+        new_samples, new_new_valid_length = \
+            F.contrib.cond(F.sum(new_beam_alive_mask) == 0, _then_func, _else_func)
         return F.round(new_samples).astype(np.int32),\
                new_scores,\
                F.round(new_new_valid_length).astype(np.int32)
diff --git a/tests/unittest/test_beam_search.py b/tests/unittest/test_beam_search.py
index 70cc1ddd1b..cbb74ed894 100644
--- a/tests/unittest/test_beam_search.py
+++ b/tests/unittest/test_beam_search.py
@@ -236,18 +236,21 @@ def forward(self, inputs, states):
             continue
         for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
             scorer = BeamSearchScorer(alpha=alpha, K=K)
-            for max_length in [3, 5]:
+            for max_length in [2, 3]:
                 for batch_size in [1, 2, 5]:
                     if sampler_cls is HybridBeamSearchSampler:
-                        sampler = sampler_cls(batch_size=batch_size, beam_size=beam_size,
-                                              decoder=decoder, eos_id=eos_id, vocab_size=vocab_num,
-                                              scorer=scorer, max_length=max_length)
+                        sampler = sampler_cls(beam_size=beam_size, decoder=decoder,
+                                              eos_id=eos_id,
+                                              scorer=scorer, max_length=max_length,
+                                              vocab_size=vocab_num, batch_size=batch_size)
                         if hybridize:
                             sampler.hybridize()
                     else:
-                        sampler = sampler_cls(beam_size=beam_size, decoder=decoder, eos_id=eos_id,
+                        sampler = sampler_cls(beam_size=beam_size, decoder=decoder,
+                                              eos_id=eos_id,
                                               scorer=scorer, max_length=max_length)
-                    print(type(decoder).__name__, beam_size, bos_id, eos_id, alpha, K, batch_size)
+                    print(type(decoder).__name__, beam_size, bos_id, eos_id, \
+                          alpha, K, batch_size)
                     states = decoder.begin_state(batch_size)
                     inputs = mx.nd.full(shape=(batch_size,), val=bos_id)
                     samples, scores, valid_length = sampler(inputs, states)

From a69a773e7563d004e8ce998047a79016335743cf Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Wed, 8 Aug 2018 01:14:52 -0700
Subject: [PATCH 08/10] Rename all vocab_num => vocab_size

---
 gluonnlp/model/beam_search.py      | 22 ++++++++--------
 tests/unittest/test_beam_search.py | 40 +++++++++++++++---------------
 2 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/gluonnlp/model/beam_search.py b/gluonnlp/model/beam_search.py
index 1edbfed370..14d1c5b0a8 100644
--- a/gluonnlp/model/beam_search.py
+++ b/gluonnlp/model/beam_search.py
@@ -263,7 +263,7 @@ def __init__(self, beam_size, eos_id, scorer, state_info, single_step=False, \
         assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
 
     def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam_alive_mask,  # pylint: disable=arguments-differ
-                       states, vocab_num, batch_shift):
+                       states, vocab_size, batch_shift):
         """
 
         Parameters
@@ -286,7 +286,7 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam
         states : nested structure of NDArrays/Symbols
             Each NDArray/Symbol should have shape (N, ...) when state_info is None,
             or same as the layout in state_info when it's not None.
-        vocab_num : NDArray or Symbol
+        vocab_size : NDArray or Symbol
             Shape (1,)
         batch_shift : NDArray or Symbol
             Contains [0, beam_size, 2 * beam_size, ..., (batch_size - 1) * beam_size].
@@ -325,11 +325,11 @@ def hybrid_forward(self, F, samples, valid_length, log_probs, scores, step, beam
                                      finished_scores, dim=1)
         # Get the top K scores
         new_scores, indices = F.topk(candidate_scores, axis=1, k=beam_size, ret_typ='both')
-        use_prev = F.broadcast_greater_equal(indices, beam_size * vocab_num)
-        chosen_word_ids = F.broadcast_mod(indices, vocab_num)
+        use_prev = F.broadcast_greater_equal(indices, beam_size * vocab_size)
+        chosen_word_ids = F.broadcast_mod(indices, vocab_size)
         beam_ids = F.where(use_prev,
-                           F.broadcast_minus(indices, beam_size * vocab_num),
-                           F.floor(F.broadcast_div(indices, vocab_num)))
+                           F.broadcast_minus(indices, beam_size * vocab_size),
+                           F.floor(F.broadcast_div(indices, vocab_size)))
         batch_beam_indices = F.broadcast_add(beam_ids, F.expand_dims(batch_shift, axis=1))
         chosen_word_ids = F.where(use_prev,
                                   -F.ones_like(indices),
                                   chosen_word_ids)
@@ -440,12 +440,12 @@ def __call__(self, inputs, states):
         samples = step_input.reshape((batch_size, beam_size, 1))
         for i in range(self._max_length):
             log_probs, new_states = self._decoder(step_input, states)
-            vocab_num_nd = mx.nd.array([log_probs.shape[1]], ctx=ctx)
+            vocab_size_nd = mx.nd.array([log_probs.shape[1]], ctx=ctx)
             batch_shift_nd = mx.nd.arange(0, batch_size * beam_size, beam_size, ctx=ctx)
             step_nd = mx.nd.array([i + 1], ctx=ctx)
             samples, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
                 self._updater(samples, valid_length, log_probs, scores, step_nd, beam_alive_mask,
-                              new_states, vocab_num_nd, batch_shift_nd)
+                              new_states, vocab_size_nd, batch_shift_nd)
             step_input = mx.nd.relu(chosen_word_ids).reshape((-1,))
             if mx.nd.sum(beam_alive_mask).asscalar() == 0:
                 return mx.nd.round(samples).astype(np.int32),\
@@ -489,7 +489,7 @@ class HybridBeamSearchSampler(HybridBlock):
         The score function used in beam search.
     max_length : int, default 100
         The maximum search length.
-    vocab_num : int, default None, meaning `decoder._vocab_size`
+    vocab_size : int, default None, meaning `decoder._vocab_size`
         The vocabulary size
     """
     def __init__(self, batch_size, beam_size, decoder, eos_id,
@@ -554,7 +554,7 @@ def hybrid_forward(self, F, inputs, states):  # pylint: disable=arguments-diffe
                 F.full(shape=(batch_size, beam_size - 1), val=LARGE_NEGATIVE_FLOAT),
                 dim=1,
             )
-        vocab_num = F.full(shape=(1, ), val=vocab_size)
+        vocab_size = F.full(shape=(1, ), val=vocab_size)
         batch_shift = F.arange(0, batch_size * beam_size, beam_size)
 
         def _loop_cond(_i, _samples, _indices, _step_input, _valid_length, _scores, \
                        beam_alive_mask, *_states):
             return F.sum(beam_alive_mask) > 0
 
         def _loop_func(i, samples, indices, step_input, valid_length, scores, \
                        beam_alive_mask, *states):
             log_probs, new_states = self._decoder(
                 step_input, _reconstruct_flattened_structure(state_structure, states))
             step = i + 1
             new_samples, new_valid_length, new_scores, \
                 chosen_word_ids, new_beam_alive_mask, new_new_states = \
                 self._updater(samples, valid_length, log_probs, scores, step, beam_alive_mask,
                               _extract_and_flatten_nested_structure(new_states)[-1],
-                              vocab_num, batch_shift)
+                              vocab_size, batch_shift)
             new_step_input = F.relu(chosen_word_ids).reshape((-1,))
diff --git a/tests/unittest/test_beam_search.py b/tests/unittest/test_beam_search.py
index cbb74ed894..c33d09256b 100644
--- a/tests/unittest/test_beam_search.py
+++ b/tests/unittest/test_beam_search.py
@@ -108,7 +108,7 @@ def _npy_beam_search(decoder, scorer, inputs, states, eos_id, beam_size, max_len
         state_info = None
     for step in range(max_length):
         log_probs, states = decoder(mx.nd.array(inputs), states)
-        vocab_num = log_probs.shape[1]
+        vocab_size = log_probs.shape[1]
         candidate_scores = scorer(log_probs, mx.nd.array(scores),
                                   mx.nd.array([step + 1])).asnumpy()
         beam_done_inds = np.where(beam_done)[0]
@@ -124,12 +124,12 @@ def _npy_beam_search(decoder, scorer, inputs, states, eos_id, beam_size, max_len
         sel_beam_ids = []
         new_scores = candidate_scores[indices]
         for ind in indices:
-            if ind < beam_size * vocab_num:
-                sel_words.append(ind % vocab_num)
-                sel_beam_ids.append(ind // vocab_num)
+            if ind < beam_size * vocab_size:
+                sel_words.append(ind % vocab_size)
+                sel_beam_ids.append(ind // vocab_size)
             else:
                 sel_words.append(-1)
-                sel_beam_ids.append(beam_done_inds[ind - beam_size * vocab_num])
+                sel_beam_ids.append(beam_done_inds[ind - beam_size * vocab_size])
         states = _get_new_states(states, state_info, sel_beam_ids)
         samples = np.concatenate((samples[sel_beam_ids, :],
                                   np.expand_dims(np.array(sel_words), axis=1)), axis=1)
@@ -145,13 +145,13 @@ def _npy_beam_search(decoder, scorer, inputs, states, eos_id, beam_size, max_len
     HIDDEN_SIZE = 2
     class RNNDecoder(HybridBlock):
-        def __init__(self, vocab_num, hidden_size, prefix=None, params=None):
+        def __init__(self, vocab_size, hidden_size, prefix=None, params=None):
             super(RNNDecoder, self).__init__(prefix=prefix, params=params)
-            self._vocab_num = vocab_num
+            self._vocab_size = vocab_size
             with self.name_scope():
-                self._embed = nn.Embedding(input_dim=vocab_num, output_dim=hidden_size)
+                self._embed = nn.Embedding(input_dim=vocab_size, output_dim=hidden_size)
                 self._rnn = RNNCell(hidden_size=hidden_size)
-                self._map_to_vocab = nn.Dense(vocab_num)
+                self._map_to_vocab = nn.Dense(vocab_size)
 
         def begin_state(self, batch_size):
             return self._rnn.begin_state(batch_size=batch_size,
@@ -163,15 +163,15 @@ def hybrid_forward(self, F, inputs, states):
             return log_probs, states
 
     class RNNDecoder2(HybridBlock):
-        def __init__(self, vocab_num, hidden_size, prefix=None, params=None, use_tuple=False):
+        def __init__(self, vocab_size, hidden_size, prefix=None, params=None, use_tuple=False):
             super(RNNDecoder2, self).__init__(prefix=prefix, params=params)
-            self._vocab_num = vocab_num
+            self._vocab_size = vocab_size
             self._use_tuple = use_tuple
             with self.name_scope():
-                self._embed = nn.Embedding(input_dim=vocab_num, output_dim=hidden_size)
+                self._embed = nn.Embedding(input_dim=vocab_size, output_dim=hidden_size)
                 self._rnn1 = RNNCell(hidden_size=hidden_size)
                 self._rnn2 = RNNCell(hidden_size=hidden_size)
-                self._map_to_vocab = nn.Dense(vocab_num)
+                self._map_to_vocab = nn.Dense(vocab_size)
 
         def begin_state(self, batch_size):
             ret = [self._rnn1.begin_state(batch_size=batch_size,
@@ -198,13 +198,13 @@ def hybrid_forward(self, F, inputs, states):
             return log_probs, states
 
     class RNNLayerDecoder(Block):
-        def __init__(self, vocab_num, hidden_size, prefix=None, params=None):
+        def __init__(self, vocab_size, hidden_size, prefix=None, params=None):
             super(RNNLayerDecoder, self).__init__(prefix=prefix, params=params)
-            self._vocab_num = vocab_num
+            self._vocab_size = vocab_size
             with self.name_scope():
-                self._embed = nn.Embedding(input_dim=vocab_num, output_dim=hidden_size)
+                self._embed = nn.Embedding(input_dim=vocab_size, output_dim=hidden_size)
                 self._rnn = RNN(hidden_size=hidden_size, num_layers=1, activation='tanh')
-                self._map_to_vocab = nn.Dense(vocab_num, flatten=False)
+                self._map_to_vocab = nn.Dense(vocab_size, flatten=False)
 
         def begin_state(self, batch_size):
             return self._rnn.begin_state(batch_size=batch_size,
@@ -219,12 +219,12 @@ def forward(self, inputs, states):
             return log_probs, states
 
     # Begin Testing
-    for vocab_num in [4, 8]:
+    for vocab_size in [4, 8]:
         for decoder_fn in [RNNDecoder,
                            functools.partial(RNNDecoder2, use_tuple=False),
                            functools.partial(RNNDecoder2, use_tuple=True),
                            RNNLayerDecoder]:
-            decoder = decoder_fn(vocab_num=vocab_num, hidden_size=HIDDEN_SIZE)
+            decoder = decoder_fn(vocab_size=vocab_size, hidden_size=HIDDEN_SIZE)
             decoder.hybridize()
             decoder.initialize()
             if hasattr(decoder, 'state_info'):
@@ -242,7 +242,7 @@ def forward(self, inputs, states):
                         sampler = sampler_cls(beam_size=beam_size, decoder=decoder,
                                               eos_id=eos_id,
                                               scorer=scorer, max_length=max_length,
-                                              vocab_size=vocab_num, batch_size=batch_size)
+                                              vocab_size=vocab_size, batch_size=batch_size)
                         if hybridize:
                             sampler.hybridize()
                     else:

From a13a6810a2fc0b5bf528a354feacceaa8a2f50e6 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Tue, 14 Aug 2018 14:03:41 -0700
Subject: [PATCH 09/10] Change RNNLayerDecoder to HybridBlock

---
 tests/unittest/test_beam_search.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/tests/unittest/test_beam_search.py b/tests/unittest/test_beam_search.py
index c33d09256b..65ac809c50 100644
--- a/tests/unittest/test_beam_search.py
+++ b/tests/unittest/test_beam_search.py
@@ -197,7 +197,7 @@ def hybrid_forward(self, F, inputs, states):
             states = [states1, states2]
             return log_probs, states
 
-    class RNNLayerDecoder(Block):
+    class RNNLayerDecoder(HybridBlock):
         def __init__(self, vocab_size, hidden_size, prefix=None, params=None):
             super(RNNLayerDecoder, self).__init__(prefix=prefix, params=params)
             self._vocab_size = vocab_size
@@ -213,7 +213,7 @@ def begin_state(self, batch_size):
         def state_info(self, *args, **kwargs):
             return self._rnn.state_info(*args, **kwargs)
 
-        def forward(self, inputs, states):
+        def hybrid_forward(self, F, inputs, states):
             out, states = self._rnn(self._embed(inputs.expand_dims(0)), states)
             log_probs = self._map_to_vocab(out)[0].log_softmax()
             return log_probs, states
@@ -231,9 +231,6 @@ def forward(self, inputs, states):
             state_info = decoder.state_info()
         else:
             state_info = None
-        if sampler_cls is HybridBeamSearchSampler and decoder_fn is RNNLayerDecoder:
-            # Hybrid beam search does not work on non-hybridizable object
-            continue
         for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
             scorer = BeamSearchScorer(alpha=alpha, K=K)
             for max_length in [2, 3]:

From 016acacd9971108b8806f017bf3fb0e561da88b8 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Wed, 15 Aug 2018 20:20:12 +0800
Subject: [PATCH 10/10] Symbol[0] => Symbol.squeeze(axis=0)

---
 tests/unittest/test_beam_search.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/unittest/test_beam_search.py b/tests/unittest/test_beam_search.py
index 65ac809c50..511ab937d0 100644
--- a/tests/unittest/test_beam_search.py
+++ b/tests/unittest/test_beam_search.py
@@ -215,11 +215,11 @@ def state_info(self, *args, **kwargs):
 
         def hybrid_forward(self, F, inputs, states):
             out, states = self._rnn(self._embed(inputs.expand_dims(0)), states)
-            log_probs = self._map_to_vocab(out)[0].log_softmax()
+            log_probs = self._map_to_vocab(out).squeeze(axis=0).log_softmax()
             return log_probs, states
 
     # Begin Testing
-    for vocab_size in [4, 8]:
+    for vocab_size in [2, 3]:
         for decoder_fn in [RNNDecoder,
                            functools.partial(RNNDecoder2, use_tuple=False),
                            functools.partial(RNNDecoder2, use_tuple=True),
@@ -234,7 +234,7 @@ def hybrid_forward(self, F, inputs, states):
             scorer = BeamSearchScorer(alpha=alpha, K=K)
             for max_length in [2, 3]:
-                for batch_size in [1, 2, 5]:
+                for batch_size in [1, 5]:
                     if sampler_cls is HybridBeamSearchSampler:
                         sampler = sampler_cls(beam_size=beam_size, decoder=decoder,
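
Usage sketch (distilled from tests/unittest/test_beam_search.py above; the RNNDecoder
class and the concrete sizes, bos/eos ids, and scorer parameters are illustrative values
borrowed from one of the test configurations, not requirements of the API). Unlike
BeamSearchSampler, HybridBeamSearchSampler takes batch_size and vocab_size at
construction time, since the symbolic while_loop fixes the shapes of its loop variables
when the graph is built:

    import mxnet as mx
    from gluonnlp.model import BeamSearchScorer, HybridBeamSearchSampler

    # RNNDecoder is the toy hybridizable decoder defined in the test file; any
    # hybridizable callable with the contract
    #     log_probs, new_states = decoder(step_input, states)
    # can be substituted here.
    decoder = RNNDecoder(vocab_size=4, hidden_size=2)
    decoder.hybridize()
    decoder.initialize()

    sampler = HybridBeamSearchSampler(batch_size=2, beam_size=4, decoder=decoder,
                                      eos_id=3, scorer=BeamSearchScorer(alpha=1.0, K=5.0),
                                      max_length=3, vocab_size=4)
    sampler.hybridize()  # compiles the entire search into one symbolic graph

    inputs = mx.nd.full(shape=(2,), val=2)   # begin-of-sentence token ids
    states = decoder.begin_state(2)
    samples, scores, valid_length = sampler(inputs, states)
    # samples: (2, 4, length) int32; scores: (2, 4) with scores[i, :] in
    # descending order; valid_length: (2, 4) int32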