Extending Self Attention #391

@Zettelkasten

I want to make some changes to the self-attention used in the Transformer for MT, namely implement locality-sensitive hashing attention (https://arxiv.org/pdf/2001.04451.pdf).

Right now, self-attention is a single layer within RETURNN. While this is very convenient when using the default configuration, it is not very extensible: all options for it have been implemented as additional arguments to the layer, and the code has become pretty messy over time.
I could implement my changes within the layer by adding yet another parameter, but I think it would be better not to clutter the self-attention layer with even more (relatively specific) parameters.

Instead it might be nicer to implement them using other existing RETURNN layers, similar to how encoder-decoder attention is implemented in our Trafo config.
For unmasked self-attention (where one can attend to the entire sequence, e.g. in the encoder), I don't see an issue in implementing it completely analogously to the encoder-decoder attention:
Use three linear layers to obtain queries, keys and values. Compare all queries against all keys using a dot layer, then use a softmax_over_spatial layer to turn these attention energies into attention weights. Finally, use a generic_attention layer to compute the weighted sum of the attention values.
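Roughly, I imagine a net-dict fragment like the following (a sketch only, not tested; the layer names, the dimension 64 and especially the dot-layer axis arguments are placeholders, and the axis handling is the part I am unsure about):

```python
# Untested sketch of unmasked self-attention built from existing layers,
# following the encoder-decoder attention pattern of the Trafo config.
self_att_sketch = {
    "query": {"class": "linear", "from": "data", "activation": None,
              "with_bias": False, "n_out": 64},
    "key": {"class": "linear", "from": "data", "activation": None,
            "with_bias": False, "n_out": 64},
    "value": {"class": "linear", "from": "data", "activation": None,
              "with_bias": False, "n_out": 64},
    # Compare every query with every key -> attention energies.
    # Caveat: query and key share the same time axis here, so one of them
    # would probably need a distinct dim tag (e.g. via reinterpret_data).
    "energy": {"class": "dot", "from": ["query", "key"],
               "red1": "F", "red2": "F", "var1": "T", "var2": "T"},
    # Softmax over the key positions; masked by the sequence length.
    "att_weights": {"class": "softmax_over_spatial", "from": "energy",
                    "energy_factor": 64 ** -0.5},
    # Weighted sum of the values.
    "att": {"class": "generic_attention", "weights": "att_weights",
            "base": "value"},
}
```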

For masked self-attention (where one cannot attend to future positions, e.g. used in the decoder), there are two things to consider:

  • We have to mask all future target positions by setting their attention energies to -inf. This could for example be done in the softmax layer (which already considers the total sequence length anyway); see the sketch after this list.
  • When we are inside a recurrent layer (e.g. during Trafo inference), we would like to cache all previously computed attention keys, values and queries. The current self-attention layer does this, but it is also one of the reasons why it is currently messy to extend: the recurrent and the parallel case are handled somewhat differently.
    I have no idea how that should look. What would the linear layers generating attention keys and values return inside a recurrent loop (wouldn't that introduce a time axis even there)? And how do we avoid recomputing the old keys/values?
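To make the first point concrete, here is a minimal standalone sketch of what setting the future energies to -inf before the softmax amounts to (plain NumPy for illustration only, not RETURNN layer code):

```python
import numpy as np

def masked_softmax(energy):
    """Toy illustration of the masking step.
    energy: [T_query, T_key] attention energies for a single sequence."""
    t_q, t_k = energy.shape
    # True for every (query i, key j) pair with j > i, i.e. future positions.
    future = np.triu(np.ones((t_q, t_k), dtype=bool), k=1)
    energy = np.where(future, -np.inf, energy)
    # Standard (numerically stabilized) softmax over the key axis;
    # exp(-inf) = 0, so future positions get zero attention weight.
    e = np.exp(energy - energy.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)
```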

What would be the best approach to extend self-attention? Stick to changing the RETURNN code of the layer? Or implement it with multiple layers, and if so, how do I solve the problems mentioned above?
Thanks :)
