This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Add symbolic beam search
junrushao committed Jul 27, 2018
1 parent 369d6f7 commit 1e1fdd3
Showing 1 changed file with 157 additions and 4 deletions.
161 changes: 157 additions & 4 deletions gluonnlp/model/beam_search.py
@@ -20,7 +20,7 @@
from __future__ import absolute_import
from __future__ import print_function

__all__ = ['BeamSearchScorer', 'BeamSearchSampler']
__all__ = ['BeamSearchScorer', 'BeamSearchSampler', 'HybridizedBeamSearchSampler']

import numpy as np
import mxnet as mx
@@ -80,7 +80,7 @@ def _expand_to_beam_size(data, beam_size, batch_size, state_info=None):
Parameters
----------
data : A single NDArray or nested container with NDArrays
data : A single NDArray/Symbol or nested container with NDArrays/Symbol
Each NDArray/Symbol should have shape (N, ...) when state_info is None,
or same as the layout in state_info when it's not None.
beam_size : int
@@ -92,8 +92,8 @@ def _expand_to_beam_size(data, beam_size, batch_size, state_info=None):
When None, this method assumes that the batch axis is the first dimension.
Returns
-------
new_states : Object that contains NDArrays
Each NDArray should have shape batch_size * beam_size on the batch axis.
new_states : Object that contains NDArrays/Symbols
Each NDArray/Symbol should have shape batch_size * beam_size on the batch axis.
"""
assert not state_info or isinstance(state_info, (type(data), dict)), \
'data and state_info doesn\'t match, ' \
@@ -128,6 +128,15 @@ def _expand_to_beam_size(data, beam_size, batch_size, state_info=None):
return data.expand_dims(batch_axis+1)\
.broadcast_axes(axis=batch_axis+1, size=beam_size)\
.reshape(new_shape)
elif isinstance(data, mx.sym.Symbol):
if not state_info:
batch_axis = 0
else:
batch_axis = state_info['__layout__'].find('N')
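# Reshape codes: 0 copies a leading dimension unchanged, -3 merges the next two axes (batch and beam) into one, and -2 copies all remaining dimensions.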
new_shape = (0, ) * batch_axis + (-3, -2)
return data.expand_dims(batch_axis+1)\
.broadcast_axes(axis=batch_axis+1, size=beam_size)\
.reshape(new_shape)
else:
raise NotImplementedError

@@ -384,3 +393,147 @@ def __call__(self, inputs, states):
return mx.nd.round(samples).astype(np.int32),\
scores,\
mx.nd.round(valid_length).astype(np.int32)


class HybridizedBeamSearchSampler(HybridBlock):
r"""Draw samples from the decoder by beam search.
Parameters
----------
batch_size : int
The batch size.
beam_size : int
The beam size.
decoder : callable, should be hybridizable
Function of the one-step-ahead decoder, should have the form::
log_probs, new_states = decoder(step_input, states)
The inputs and outputs should follow these rules:
- step_input has shape (batch_size,),
- log_probs has shape (batch_size, V),
- states and new_states have the same structure and the leading
dimension of the inner NDArrays is the batch dimension.
eos_id : int
Id of the EOS token. No other elements will be appended to the sample if it reaches eos_id.
scorer : BeamSearchScorer, default BeamSearchScorer(alpha=1.0, K=5)
The score function used in beam search.
max_length : int, default 100
The maximum search length.
vocab_size : int, default None, meaning `decoder._vocab_size`
The vocabulary size.
"""
def __init__(self, batch_size, beam_size, decoder, eos_id,
scorer=BeamSearchScorer(alpha=1.0, K=5),
max_length=100,
vocab_size=None):
super(HybridizedBeamSearchSampler, self).__init__()
self._batch_size = batch_size
self._beam_size = beam_size
assert beam_size > 0,\
'beam_size must be larger than 0. Received beam_size={}'.format(beam_size)
self._decoder = decoder
self._eos_id = eos_id
assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
self._max_length = max_length
self._scorer = scorer
self._state_info_func = getattr(decoder, 'state_info', lambda _=None: None)
self._updater = _BeamSearchStepUpdate(beam_size=beam_size, eos_id=eos_id, scorer=scorer,
state_info=self._state_info_func())
self._updater.hybridize()
self._vocab_size = vocab_size or getattr(decoder, '_vocab_size', None)
assert self._vocab_size is not None,\
'Please provide vocab_size or define decoder._vocab_size'
assert not hasattr(decoder, '_vocab_size') or decoder._vocab_size == self._vocab_size, \
'Provided vocab_size={} is not equal to decoder._vocab_size={}'\
.format(self._vocab_size, decoder._vocab_size)

def hybrid_forward(self, F, inputs, states): # pylint: disable=arguments-differ
"""Sample by beam search.
Parameters
----------
F
inputs : NDArray or Symbol
The initial input of the decoder. Shape is (batch_size,).
states : Object that contains NDArrays or Symbols
The initial states of the decoder.
Returns
-------
samples : NDArray or Symbol
Samples drawn by beam search. Shape (batch_size, beam_size, length). dtype is int32.
scores : NDArray or Symbol
Scores of the samples. Shape (batch_size, beam_size). We make sure that scores[i, :] are
in descending order.
valid_length : NDArray or Symbol
The valid length of the samples. Shape (batch_size, beam_size). dtype will be int32.
"""
batch_size = self._batch_size
beam_size = self._beam_size
vocab_size = self._vocab_size
# Tile the states and inputs to have shape (batch_size * beam_size, ...)
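# The decoder's state_info (when provided) describes the layout of each state tensor so that _expand_to_beam_size can locate the batch axis; otherwise the batch axis is assumed to be the first dimension.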
state_info = self._state_info_func(batch_size)
states = _expand_to_beam_size(states, beam_size=beam_size, batch_size=batch_size,
state_info=state_info)
step_input = _expand_to_beam_size(inputs, beam_size=beam_size, batch_size=batch_size)
# All beams are initialized to alive
# Generated samples are initialized to be the inputs
# Except the first beam where the scores are set to be zero, all beams have -inf scores.
# Valid length is initialized to be 1
beam_alive_mask = F.ones(shape=(batch_size, beam_size))
valid_length = F.ones(shape=(batch_size, beam_size))
if beam_size == 1:
scores = F.zeros(shape=(batch_size, beam_size))
else:
scores = F.concat(
F.zeros(shape=(batch_size, 1)),
F.full(shape=(batch_size, beam_size - 1), val=LARGE_NEGATIVE_FLOAT),
dim=1,
)
samples = step_input.reshape((batch_size, beam_size, 1))
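# Auxiliary constants passed to the updater at every step: vocab_num encodes the vocabulary size, and batch_shift holds the offset of each batch element's first row in the flattened (batch_size * beam_size) axis.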
vocab_num = F.ones([vocab_size])
batch_shift = F.arange(0, batch_size * beam_size, beam_size)

def _loop_cond(_i, _step_input, _states, _samples, _valid_length, _scores, beam_alive_mask):
return F.sum(beam_alive_mask) > 0

def _loop_func(i, step_input, states, samples, valid_length, scores, beam_alive_mask):
log_probs, new_states = self._decoder(step_input, states)
step = i + 1
samples, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
self._updater(samples, valid_length, log_probs, scores, step, beam_alive_mask,
new_states, vocab_num, batch_shift)
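# Finished beams receive a negative word id from the updater; relu maps it to 0 so the next decoder input is still a valid token index.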
step_input = F.relu(chosen_word_ids).reshape((-1,))
return (), (step, step_input, states, samples, valid_length, scores, beam_alive_mask)

# F.contrib.while_loop returns a pair (per-step outputs, final loop variables); _loop_func emits no per-step outputs, so only the final loop variables are unpacked here.
_, (_, step_input, states, samples, valid_length, scores, beam_alive_mask) = \
F.contrib.while_loop(
cond=_loop_cond, func=_loop_func, max_iterations=self._max_length,
loop_vars=(
F.zeros(shape=(1, )), # i
step_input,
states,
samples,
valid_length,
scores,
beam_alive_mask
)
)

def _then_func():
return samples, valid_length

def _else_func():
final_word = F.where(beam_alive_mask,
F.full(shape=(batch_size, beam_size),
val=self._eos_id),
F.full(shape=(batch_size, beam_size),
val=-1))
new_samples = F.concat(samples, final_word.reshape((0, 0, 1)), dim=2)
new_valid_length = valid_length + beam_alive_mask
return new_samples, new_valid_length

# If any beam is still alive when the loop stops, _else_func appends <eos> to those beams and extends their valid_length; otherwise _then_func returns the results unchanged.
samples, valid_length = F.contrib.cond(F.sum(beam_alive_mask) == 0, _then_func, _else_func)
return F.round(samples).astype(np.int32),\
scores,\
F.round(valid_length).astype(np.int32)
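
As a quick orientation for reviewers, the sketch below shows one way the new sampler might be driven once this change lands. Only BeamSearchScorer and HybridizedBeamSearchSampler come from gluonnlp.model.beam_search; the ToyDecoder block, its sizes, and the token ids (eos_id=3, a <bos> id of 1) are illustrative assumptions, not part of this commit.

import mxnet as mx
from mxnet.gluon import HybridBlock, nn
from gluonnlp.model.beam_search import BeamSearchScorer, HybridizedBeamSearchSampler


class ToyDecoder(HybridBlock):
    """Toy one-step decoder: adds the embedded previous token to a running
    hidden state and emits log-probabilities over the vocabulary."""
    def __init__(self, vocab_size=10, hidden_size=16, **kwargs):
        super(ToyDecoder, self).__init__(**kwargs)
        self._vocab_size = vocab_size  # read by the sampler when vocab_size is not given
        with self.name_scope():
            self.embed = nn.Embedding(vocab_size, hidden_size)
            self.proj = nn.Dense(vocab_size)

    def hybrid_forward(self, F, step_input, states):  # pylint: disable=arguments-differ
        hidden = states + self.embed(step_input)      # states: (batch, hidden)
        log_probs = F.log_softmax(self.proj(hidden))  # (batch, vocab_size)
        return log_probs, hidden


batch_size, beam_size, hidden_size = 2, 4, 16
decoder = ToyDecoder(vocab_size=10, hidden_size=hidden_size)
decoder.initialize()
decoder.hybridize()

sampler = HybridizedBeamSearchSampler(
    batch_size=batch_size, beam_size=beam_size, decoder=decoder, eos_id=3,
    scorer=BeamSearchScorer(alpha=1.0, K=5), max_length=20)
sampler.hybridize()

inputs = mx.nd.full(shape=(batch_size,), val=1)        # hypothetical <bos> ids
states = mx.nd.zeros(shape=(batch_size, hidden_size))  # initial hidden state
samples, scores, valid_length = sampler(inputs, states)
print(samples.shape, scores.shape, valid_length.shape)
# expected: (2, 4, length), (2, 4), (2, 4)

Because the decoder, the step updater, and the sampler are all HybridBlocks built on F.contrib.while_loop and F.contrib.cond, the whole search can be hybridized into a single symbolic graph, which is the point of this change.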
