diff --git a/tests/unittest/test_beam_search.py b/tests/unittest/test_beam_search.py index c33d09256b..65ac809c50 100644 --- a/tests/unittest/test_beam_search.py +++ b/tests/unittest/test_beam_search.py @@ -197,7 +197,7 @@ def hybrid_forward(self, F, inputs, states): states = [states1, states2] return log_probs, states - class RNNLayerDecoder(Block): + class RNNLayerDecoder(HybridBlock): def __init__(self, vocab_size, hidden_size, prefix=None, params=None): super(RNNLayerDecoder, self).__init__(prefix=prefix, params=params) self._vocab_size = vocab_size @@ -213,7 +213,7 @@ def begin_state(self, batch_size): def state_info(self, *args, **kwargs): return self._rnn.state_info(*args, **kwargs) - def forward(self, inputs, states): + def hybrid_forward(self, F, inputs, states): out, states = self._rnn(self._embed(inputs.expand_dims(0)), states) log_probs = self._map_to_vocab(out)[0].log_softmax() return log_probs, states @@ -231,9 +231,6 @@ def forward(self, inputs, states): state_info = decoder.state_info() else: state_info = None - if sampler_cls is HybridBeamSearchSampler and decoder_fn is RNNLayerDecoder: - # Hybrid beam search does not work on non-hybridizable object - continue for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]: scorer = BeamSearchScorer(alpha=alpha, K=K) for max_length in [2, 3]: