Skip to content

Commit

Permalink
Add ability to print decoded intent
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenlugosch committed Apr 13, 2019
1 parent 3f6a562 commit 7709dfa
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ class Model(torch.nn.Module):
def __init__(self, config):
super(Model, self).__init__()
self.is_cuda = torch.cuda.is_available()
self.Sy_intent = config.Sy_intent
pretrained_model = PretrainedModel(config)
if config.pretraining_type != 0:
pretrained_model_path = os.path.join(config.folder, "pretraining", "model_state.pth")
Expand Down Expand Up @@ -545,3 +546,15 @@ def predict_intents(self, x):
predicted_intent = torch.stack(predicted_intent, dim=1)

return intent_logits, predicted_intent

def decode_intents(self, x):
_, predicted_intent = self.predict_intents(x)
intents = []
for prediction in predicted_intent:
intent = []
for idx, slot in enumerate(self.Sy_intent):
for value in self.Sy_intent[slot]:
if prediction[idx].item() == self.Sy_intent[slot][value]:
intent.append(value)
intents.append(intent)
return intents

0 comments on commit 7709dfa

Please sign in to comment.