diff --git a/surprisal/model.py b/surprisal/model.py
index d9b7815..0175b93 100644
--- a/surprisal/model.py
+++ b/surprisal/model.py
@@ -186,7 +186,22 @@ def surprise(

         mask_mask = torch.eye(n, n)[1:, :].repeat(b, 1).bool()
         ids_with_bos_token[mask_mask] = self.tokenizer.mask_token_id
-        import IPython
+        # below is from ckauf and neuranna?
+        # if "within_word_l2r" == PLL_metric:
+        #     """
+        #     Future tokens belonging to the same word as the target token are masked during token inference as well.
+        #     """
+        #     mask_indices = [
+        #         [mask_pos]
+        #         + [
+        #             j
+        #             for j in range(mask_pos + 1, effective_length + 2)
+        #             if word_ids[j] == word_ids[mask_pos]
+        #         ]
+        #         if word_ids[mask_pos] is not None
+        #         else [mask_pos]
+        #         for mask_pos in range(effective_length + 2)
+        #     ]
         raise NotImplementedError
