Bug description
Hi Sebastian,
I think this is not a bug but a possible enhancement. To apply the mask, we currently need two steps:
- Creating a lower-triangular matrix of ones and zeros:
mask_simple = torch.tril(torch.ones(context_length, context_length))
- Multiplying the attention matrix by the triangular mask:
masked_simple = attn_weights * mask_simple
However, this function (torch.tril) can be applied directly to the attention matrix to get the same result in one step:
masked_simple = torch.tril(attn_weights)
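A minimal sketch of the equivalence, using a random matrix as a stand-in for the actual attention weights (the variable names follow the snippets above):

```python
import torch

torch.manual_seed(123)
context_length = 4

# Stand-in for the real attention-weight matrix
attn_weights = torch.rand(context_length, context_length)

# Two-step approach: build the mask, then multiply
mask_simple = torch.tril(torch.ones(context_length, context_length))
masked_simple = attn_weights * mask_simple

# One-step approach: apply torch.tril directly to the attention matrix
masked_direct = torch.tril(attn_weights)

print(torch.equal(masked_simple, masked_direct))
```

Both approaches zero out the entries above the diagonal and leave the rest untouched, so the results match.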
Thank you.
What operating system are you using?
None
Where do you run your code?
None