Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 18f356e

Browse files
nshazeerCopybara-Service
authored andcommitted
Starting work on a new mesh_tensorflow Transformer implementation. This implementation lives in the mesh_tensorflow library and does not depend on Tensor2Tensor.
In the new Transformer implementation, the different kinds of layers in the transformer are subclasses of TransformerLayer. Model configurations contain lists of TransformerLayer instances (each containing its own hyperparameters). Users can add custom layers by adding new subclasses, without touching the core library. We don't have a growing list of global hyperparameters, and we don't have a giant switch statement of layer types. Add a new model "mtf_transformer2" in Tensor2Tensor to use this new implementation. We will eventually deprecate the old "mtf_transformer" model. Add a class mtf.VariableDType to encapsulate the different datatypes used for a variable: master_dtype, slice_dtype and activation_dtype, so as to avoid passing three arguments through many functions. PiperOrigin-RevId: 224003970
1 parent 217b90c commit 18f356e

File tree

5 files changed

+1694
-103
lines changed

5 files changed

+1694
-103
lines changed

0 commit comments

Comments
 (0)