translate.py not consistent #19

Open
dzjxzyd opened this issue Feb 27, 2024 · 2 comments
dzjxzyd commented Feb 27, 2024

Line 60 of translate.py:

decoder_mask = torch.triu(torch.ones((1, decoder_input.size(1), decoder_input.size(1))), diagonal=1).type(torch.int).type_as(source_mask).to(device)

Here decoder_mask is the matrix

[0,1,1
0,0,1
0,0,0]

so only the lower triangle (the zeros, diagonal included) is masked, which is different from the decoder_mask in dataset.py:

def causal_mask(size):
    # torch.triu keeps the values above the diagonal and zeros out the rest
    # diagonal=1 keeps only the values strictly above the main diagonal; with diagonal=0 the main diagonal would be kept as well
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0

In the causal_mask function, the upper triangle is masked instead.
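To make the difference concrete, here is a minimal sketch that builds both masks for a 3-token sequence and compares them. The names `translate_mask`, `dataset_mask` and `masks_are_complements` are illustrative only, and the `type_as(source_mask)` and `.to(device)` casts are left out because they do not change the pattern of zeros and ones.

```python
import torch


def causal_mask(size):
    # Same logic as the dataset.py helper quoted above.
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0


size = 3  # stands in for decoder_input.size(1)

# Expression from line 60 of translate.py, without the dtype/device casts.
translate_mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
dataset_mask = causal_mask(size).type(torch.int)

print(translate_mask[0].tolist())  # [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
print(dataset_mask[0].tolist())    # [[1, 0, 0], [1, 1, 0], [1, 1, 1]]

# Every position that is 1 in one mask is 0 in the other.
masks_are_complements = bool(torch.all(translate_mask + dataset_mask == 1))
print(masks_are_complements)  # True
```

The two masks are exact complements, so whichever convention the model uses for "masked", one of the two call sites applies the opposite of the other.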

dzjxzyd commented Feb 27, 2024

[Attached screenshot: WeChat2169b2f0b80fa9c155759775c94a7045]

dzjxzyd commented Feb 27, 2024

The prediction above was made by run_validation:
run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: print(msg), 0, None, num_examples=2)

I then changed line 60 in translate.py to
decoder_mask = (torch.triu(torch.ones((1, decoder_input.size(1), decoder_input.size(1))), diagonal=1)==0).type(torch.int).type_as(source_mask).to(device) # invert the decoder_mask
With this change, the prediction is consistent.
The bottom one uses the original code at line 60:
torch.triu(torch.ones((1, decoder_input.size(1), decoder_input.size(1))), diagonal=1).type(torch.int).type_as(source_mask).to(device)
which is not consistent.
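One way to see why the two versions could give different predictions is to apply each mask to a toy attention step. The sketch below assumes the model blocks positions where the mask equals 0 by filling them with a large negative score before the softmax; that convention is not shown in this thread, so treat it as an assumption. The helper `masked_softmax` and the all-zero `scores` tensor are hypothetical and exist only for illustration.

```python
import torch

size = 3  # stands in for decoder_input.size(1)
scores = torch.zeros(1, size, size)  # hypothetical uniform attention scores

# Original expression at line 60 (dtype/device casts omitted).
original = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
# Inverted expression proposed in the comment above.
fixed = (torch.triu(torch.ones((1, size, size)), diagonal=1) == 0).type(torch.int)


def masked_softmax(scores, mask):
    # ASSUMPTION: positions where mask == 0 are blocked from attention.
    return torch.softmax(scores.masked_fill(mask == 0, -1e9), dim=-1)


weights_original = masked_softmax(scores, original)
weights_fixed = masked_softmax(scores, fixed)

# First query position: which key positions receive attention weight?
print(weights_original[0, 0].tolist())  # [0.0, 0.5, 0.5]
print(weights_fixed[0, 0].tolist())     # [1.0, 0.0, 0.0]
```

Under that assumption, the original mask lets the first position attend only to later positions, while the inverted mask restricts it to itself, which matches the causal behavior of causal_mask in dataset.py.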
