Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
Change RNNLayerDecoder to HybridBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Aug 14, 2018
1 parent a69a773 commit a13a681
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions tests/unittest/test_beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def hybrid_forward(self, F, inputs, states):
states = [states1, states2]
return log_probs, states

class RNNLayerDecoder(Block):
class RNNLayerDecoder(HybridBlock):
def __init__(self, vocab_size, hidden_size, prefix=None, params=None):
super(RNNLayerDecoder, self).__init__(prefix=prefix, params=params)
self._vocab_size = vocab_size
Expand All @@ -213,7 +213,7 @@ def begin_state(self, batch_size):
def state_info(self, *args, **kwargs):
return self._rnn.state_info(*args, **kwargs)

def forward(self, inputs, states):
def hybrid_forward(self, F, inputs, states):
out, states = self._rnn(self._embed(inputs.expand_dims(0)), states)
log_probs = self._map_to_vocab(out)[0].log_softmax()
return log_probs, states
Expand All @@ -231,9 +231,6 @@ def forward(self, inputs, states):
state_info = decoder.state_info()
else:
state_info = None
if sampler_cls is HybridBeamSearchSampler and decoder_fn is RNNLayerDecoder:
# Hybrid beam search does not work on non-hybridizable object
continue
for beam_size, bos_id, eos_id, alpha, K in [(2, 1, 3, 0, 1.0), (4, 2, 3, 1.0, 5.0)]:
scorer = BeamSearchScorer(alpha=alpha, K=K)
for max_length in [2, 3]:
Expand Down

0 comments on commit a13a681

Please sign in to comment.