TransformerDecoder: optional positional encoding and final matmul #93
Conversation
    num_output: int
    logits_bias: bool
    share_embedding: bool
    use_positional_encoding: bool = True
I wonder if, instead of being a flag, this should be a configurable module, which you simply replace with a no-op if you don't want any positional encoding. This would also allow using positional encoding schemes other than sinusoidal.
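A minimal sketch of what such a configurable module could look like (the module and field names here are hypothetical, not part of this PR):

```python
import torch
from torch import nn


class NoPositionalEncoding(nn.Module):
    """No-op stand-in for when no positional encoding is wanted."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


# The config would then take any nn.Module mapping [B, T, F] -> [B, T, F],
# e.g. a sinusoidal module, a learned one, or the no-op above:
#   pos_enc: nn.Module = SinusoidalPositionalEncoding(model_dim)   # hypothetical
#   pos_enc: nn.Module = NoPositionalEncoding()
```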
Yes, agreed, it would be better to have this more dynamic.
ConformerMHSARelPosV1._sinusoidal_pe should maybe be moved to a separate function, and then you would have positional_encoding=absolute_sinusoidal_positional_encoding as the default, with None also allowed.
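A rough sketch of that idea, assuming an even model_dim; the function and config field names are only illustrative:

```python
import math

import torch


def absolute_sinusoidal_positional_encoding(seq_len: int, model_dim: int) -> torch.Tensor:
    """Standard sinusoidal positional encoding, shape [seq_len, model_dim] (model_dim assumed even)."""
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # [T, 1]
    inv_freq = torch.exp(
        torch.arange(0, model_dim, 2, dtype=torch.float32) * (-math.log(10000.0) / model_dim)
    )  # [F/2]
    pe = torch.zeros(seq_len, model_dim)
    pe[:, 0::2] = torch.sin(pos * inv_freq)
    pe[:, 1::2] = torch.cos(pos * inv_freq)
    return pe


# Possible config field, defaulting to the sinusoidal variant,
# with None meaning "no positional encoding":
#   positional_encoding: Optional[Callable[[int, int], torch.Tensor]] = absolute_sinusoidal_positional_encoding
```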
    logits_bias: bool
    share_embedding: bool
    use_positional_encoding: bool = True
    do_output_embedding_matmul: bool = True
Perhaps
    - do_output_embedding_matmul: bool = True
    + embed_outputs_to_vocab_dim: bool = True
is clearer naming-wise?
I don't think it's clearer. But I also don't like the original name. And I'm not sure whether I like the logic at all (see my separate comment on why to have the out_logits at all if they are not used).
As a first comment (I will try to comment in more detail later): the same questions have been thought about in the RF implementation, for the Transformer encoder, the decoder, and, closely related, also the Conformer encoder (to make the frontend optional, etc.). See the current RF TransformerDecoder implementation. It already has the
    @@ -190,13 +194,20 @@ def __init__(self, cfg: TransformerDecoderV1Config):
        else:
            self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias)
I just realized this sharing is weird. I would always set self.out_logits. If sharing, you can just do self.out_logits.weight = self.input_embedding.weight. That would simplify the other code.
Also, self.out_logits should always be set (None if not used), but with my suggestion you don't need to care about this anyway.
And then you would also allow logits_bias=True together with share_embedding=True.
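A minimal sketch of that simplification; the surrounding class and the cfg fields are simplified assumptions, not the PR's actual code:

```python
from torch import nn


class DecoderSketch(nn.Module):
    """Simplified __init__ logic illustrating the suggested weight tying."""

    def __init__(self, cfg):
        super().__init__()
        self.model_dim = cfg.model_dim
        self.input_embedding = nn.Embedding(cfg.num_output, self.model_dim)
        # Always create out_logits; optionally tie its weight to the input embedding.
        # The bias can still be enabled independently of the sharing.
        self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias)
        if cfg.share_embedding:
            self.out_logits.weight = self.input_embedding.weight
```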
    logits_bias: bool
    share_embedding: bool
    use_positional_encoding: bool = True
    do_output_embedding_matmul: bool = True
If this is False and cfg.share_embedding is also False, the out_logits are not used at all. Does it make sense to even have them then?
This PR makes both the positional encoding and the final matrix multiplication of the model output with the output embedding matrix optional.
This allows us to use the implementation for self-normalized LM Transformer training, where positional encoding is not required and the final matmul is replaced by another matmul in the sampling loss.
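For illustration only (not part of this PR), a rough sketch of that use case, assuming the decoder returns hidden states of shape [B, T, F] when do_output_embedding_matmul=False; the function name and signature are hypothetical:

```python
import torch
from torch import nn


def sampled_logits(hidden: torch.Tensor, embedding: nn.Embedding, sampled_ids: torch.Tensor) -> torch.Tensor:
    """Project hidden states [B, T, F] onto a sampled subset of the embedding matrix.

    This replaces the full-vocabulary matmul that the decoder skips
    when do_output_embedding_matmul=False.
    """
    sampled_weight = embedding.weight[sampled_ids]  # [num_sampled, F]
    return hidden @ sampled_weight.transpose(0, 1)  # [B, T, num_sampled]
```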
My only question is: should this be a TransformerDecoderV2 instead?