This repository was archived by the owner on Jan 21, 2025. It is now read-only.
Starting work on a new mesh_tensorflow Transformer implementation. This implementation lives in the mesh_tensorflow library and does not depend on Tensor2Tensor.
In the new Transformer implementation, the different kinds of layers in the transformer are subclasses of TransformerLayer. Model configurations contain lists of TransformerLayer instances, each carrying its own hyperparameters. Users can add custom layers by adding new subclasses, without touching the core library. This avoids both a growing list of global hyperparameters and a giant switch statement over layer types.
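The layer-as-subclass design described above can be illustrated with a minimal, library-free sketch. Everything here is an assumption made for illustration: the `call(x)` signature, the toy `ScaleLayer` and `OffsetLayer` classes, and the `run_stack` helper are hypothetical and are not the mesh_tensorflow API. The sketch only shows how a list of layer instances, each holding its own hyperparameters, replaces a global hyperparameter set and a switch on layer type.

```python
class TransformerLayer:
    """Hypothetical stand-in for the base class (not the mesh_tensorflow one)."""

    def call(self, x):
        raise NotImplementedError("Subclasses implement call().")


class ScaleLayer(TransformerLayer):
    """Toy layer: multiplies every element by its own `factor` hyperparameter."""

    def __init__(self, factor):
        self.factor = factor

    def call(self, x):
        return [v * self.factor for v in x]


class OffsetLayer(TransformerLayer):
    """Toy custom layer, added without modifying any of the code above."""

    def __init__(self, offset):
        self.offset = offset

    def call(self, x):
        return [v + self.offset for v in x]


def run_stack(layers, x):
    """Apply each layer in order; no switch statement on layer type is needed."""
    for layer in layers:
        x = layer.call(x)
    return x


# A "model configuration" is just a list of layer instances.
layers = [ScaleLayer(2.0), OffsetLayer(1.0)]
result = run_stack(layers, [1.0, 2.0])
```

Adding a new kind of layer in this sketch means writing one more subclass and placing an instance in the list; `run_stack` stays unchanged.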
Add a new model "mtf_transformer2" in Tensor2Tensor to use this new implementation. We will eventually deprecate the old "mtf_transformer" model.
Add a class mtf.VariableDType to encapsulate the different datatypes used for a variable: master_dtype, slice_dtype and activation_dtype, so as to avoid passing three arguments through many functions.
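As a rough illustration of why bundling the three datatypes helps, here is a small sketch that does not use mesh_tensorflow. The `VariableDTypeSketch` class and the `describe_variable` helper are hypothetical, and dtypes are plain strings rather than TensorFlow dtypes; only the three field names come from the commit message.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class VariableDTypeSketch:
    """Hypothetical bundle of the three dtypes named in the commit message."""

    master_dtype: str = "float32"
    slice_dtype: str = "float32"
    activation_dtype: str = "float32"


def describe_variable(name, variable_dtype):
    """Takes one dtype argument instead of three separate ones."""
    return "{}: master={}, slice={}, activation={}".format(
        name,
        variable_dtype.master_dtype,
        variable_dtype.slice_dtype,
        variable_dtype.activation_dtype,
    )


# One object travels through the call chain instead of three arguments.
mixed = VariableDTypeSketch(
    master_dtype="float32",
    slice_dtype="bfloat16",
    activation_dtype="bfloat16",
)
summary = describe_variable("kernel", mixed)
```

Functions that only forward the datatypes to a callee then need a single parameter, so adding a fourth datatype later would not change their signatures.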
PiperOrigin-RevId: 224003970