 device_mesh = None


-class Transformer(nn.Module):
-    def __init__(self, config: TransformerArgs) -> None:
+class TransformerStage(nn.Module):
+    def __init__(self, config: TransformerArgs, stage_idx: int, n_stages: int) -> None:
         super().__init__()
         self.config = config
+        self.stage_idx = stage_idx
+        self.n_stages = n_stages
+        self.layers_per_stage = config.n_layers // n_stages

         # Get device mesh
         global device_mesh
         if device_mesh is None:
             device_mesh = _mesh_resources.get_current_mesh()

-        tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
-        self.tok_embeddings = parallelize_module(
-            tok_embeddings,
-            device_mesh,
-            RowwiseParallel(input_layouts=Replicate()),
-        )
-        self.layers = nn.ModuleList(
-            TransformerBlock(config) for _ in range(config.n_layers)
-        )
-        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+        if stage_idx == 0:
+            tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+            self.tok_embeddings = parallelize_module(
+                tok_embeddings,
+                device_mesh,
+                RowwiseParallel(input_layouts=Replicate()),
+            )
+
+        # Use ModuleDict so that each layer can be assigned its layer ID in the original model
+        self.layers = nn.ModuleDict()
+        for layer_id in range(self.layers_per_stage * stage_idx, self.layers_per_stage * (stage_idx + 1)):
+            self.layers[str(layer_id)] = TransformerBlock(config)
+
+        if stage_idx == n_stages - 1:
+            self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+            self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

         # self.freqs_cis: Optional[Tensor] = None
         # self.mask_cache: Optional[Tensor] = None
@@ -67,7 +75,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
-        for b in self.layers:
+        for b in self.layers.values():
             b.attention.kv_cache = KVCache(
                 max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
             )
@@ -84,19 +92,26 @@ def setup_caches(self, max_batch_size, max_seq_length):
         )
         self.register_buffer("causal_mask", causal_mask, persistent=True)

-    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
+        if input_pos is None:
+            input_pos = torch.arange(x.shape[1], device=x.device, dtype=torch.long)
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
-        x: DTensor = self.tok_embeddings(idx)
-        # TODO: sequence parallelize this

-        for _, layer in enumerate(self.layers):
+        if self.stage_idx == 0:
+            x: DTensor = self.tok_embeddings(x)
+            # TODO: sequence parallelize this
+
+        for _, layer in self.layers.items():
             x = layer(x, input_pos, freqs_cis, mask)
-        x = self.norm(x)
-        logits = self.output(x)
-        # print(f"logits shape: {logits.shape}")
-        return logits
+
+        if self.stage_idx == self.n_stages - 1:
+            x = self.norm(x)
+            x = self.output(x)
+
+        # print(f"stage output shape: {x.shape}")
+        return x

     @classmethod
     def from_name(cls, name: str):
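
For context, a minimal sketch of how a stage defined by this class might be built on a single pipeline rank. The driver names below (n_pp_stages, pp_rank) and the default-constructed TransformerArgs are illustrative assumptions, not part of this change:

    # Illustrative sketch only: one process per pipeline stage builds just its slice of the model.
    # TransformerArgs() with defaults and the rank variables are assumptions; substitute the real
    # config and the values provided by the distributed runtime.
    n_pp_stages = 2
    pp_rank = 0  # would normally come from the process's pipeline-parallel rank

    config = TransformerArgs()
    stage = TransformerStage(config, stage_idx=pp_rank, n_stages=n_pp_stages)

    # Stage 0 owns tok_embeddings, the last stage owns norm/output, and every stage owns
    # config.n_layers // n_pp_stages TransformerBlocks, keyed in stage.layers (an nn.ModuleDict)
    # by their layer id in the original, unsplit model.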