@@ -111,3 +111,83 @@ def test_generation_varlen():
     out_varlen = torch.cat(scores, dim=1)
     print(f"Max diff: {(out_varlen - out_ref).abs().max()}")
     assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()
+
+def test_generation_varlen_with_padding():
+    seqlens = [170, 65, 100]
+    non_padded_seqlen = sum(seqlens)
+    padded_seqlen = 512
+    seqlens.append(padded_seqlen - non_padded_seqlen)
+    genlen = 20
+    total_seqlen = sum(seqlens)
+    assert total_seqlen == padded_seqlen
+    device = "cuda"
+    dtype = torch.float16
+
+    config = MambaConfig(
+        d_model=1024,
+        n_layer=4,
+        vocab_size=50277,
+        ssm_cfg=dict(layer="Mamba2"),
+        rms_norm=True,
+        residual_in_fp32=True,
+        fused_add_norm=True,
+        pad_vocab_size_multiple=16,
+    )
+    torch.manual_seed(2357)
+    model = MambaLMHeadModel(config, device=device, dtype=dtype)
+    xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens]
+
+    # Reference 1: Forward pass with seq_idx
+    x = torch.cat(xs[:-1], dim=1)
+    seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device)
+                         for i, ids in enumerate(xs[:-1])], dim=0).unsqueeze(0)
+    cu_seqlens = F.pad(torch.tensor(seqlens[:-1], device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
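+    # cu_seqlens now holds the packed boundaries of the three real sequences: [0, 170, 235, 335]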
+
+    out_ref = model(x, seq_idx=seq_idx).logits
+    # Keep only the logits that predict the last @genlen tokens of each sequence
+    out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1]
+                         for i in range(len(seqlens) - 1)], dim=0)
+
+    # Reference 2: Generate the last @genlen tokens of each sequence in a for loop
+    out_loop = []
+    for input_ids in xs[:-1]:
+        out = model.generate(
+            input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True,
+            return_dict_in_generate=True, cg=True, teacher_outputs=input_ids,
+        ).scores
+        out_loop.append(torch.stack(out, dim=1))
+    out_loop = torch.cat(out_loop, dim=0)
+    print(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}")
+
+    # Varlen generation
+    input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1)
+    prompt_seqlens = [seqlen - genlen for seqlen in seqlens]
+    cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
+    seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device)
+                         for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0)
+    inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens))
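+    # The decode batch has one row per entry in seqlens, including the padding pseudo-sequence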
+
+    # Account for padding: mark the padding tokens with seq_idx = -1 and give the
+    # padding "sequence" zero length in cu_seqlens so it contributes nothing
+    seq_idx[:, cu_seqlens[-2]:] = -1
+    cu_seqlens[-1] = cu_seqlens[-2]
+
+    scores, sequences = [], []
+    # Both seq_idx and cu_seqlens must be passed in for varlen generation
+    logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits
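+    # The prefill logits at positions cu_seqlens[1:] - 1 are those of the last prompt token of each
+    # sequence; the padding entry duplicates the last real one and is dropped before the final check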
+    logits = rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d")
+    scores.append(logits)
+    # In practice we would sample here; for testing we take the tokens from the teacher outputs (xs)
+    sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), "b -> b 1")
+    sequences.append(sampled_tokens)
+    for i in range(1, genlen):
+        inference_params.seqlen_offset += 1
+        logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits
+        scores.append(logits)
+        # In practice we would sample here; for testing we take the tokens from the teacher outputs (xs)
+        sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), "b -> b 1")
+        sequences.append(sampled_tokens)
+    out_varlen = torch.cat(scores, dim=1)
+
+    print(f"Max diff: {(out_varlen[:-1] - out_ref).abs().max()}")
+    assert (out_varlen[:-1] - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()