@@ -100,7 +100,6 @@ def decode(
     eos_token_id=None,
     teacher_outputs=None,
     vocab_size=None,
-    tensor_parallel=1,
     cg=False,
     enable_timing=False,
     streamer: Optional[TextStreamer] = None
@@ -134,7 +133,6 @@ def decode(
             batch_size,
             seqlen_og,
             max_length,
-            tensor_parallel=tensor_parallel,
         )
         inference_params = model._decoding_cache.inference_params
         inference_params.reset(max_length, batch_size)
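
For orientation, here is a minimal caller-side sketch of the CUDA-graph path after this change: with cg=True, decode builds or reuses model._decoding_cache through update_graph_cache and no longer accepts a tensor_parallel argument. The leading positional arguments (input_ids, model, max_length) and the concrete model object are assumptions not shown in this hunk.

import torch

# Hypothetical usage sketch, not part of this commit. `model` stands in for any
# instantiated model that implements allocate_inference_cache(batch_size, max_seqlen, dtype).
input_ids = torch.randint(0, 50257, (1, 16), dtype=torch.long, device="cuda")  # arbitrary vocab size
output = decode(
    input_ids,
    model,
    max_length=64,
    cg=True,             # populates model._decoding_cache via update_graph_cache
    enable_timing=True,  # prints prompt processing + decoding time
)
print(output.sequences.shape)
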
@@ -186,8 +184,6 @@ def should_stop(current_token, inference_params):
     end = torch.cuda.Event(enable_timing=enable_timing)

     if enable_timing:
-        if tensor_parallel > 1:
-            torch.distributed.barrier()
         start.record()
     scores, sequences = [], [input_ids]
     while not should_stop(sequences[-1], inference_params):
@@ -201,8 +197,6 @@ def should_stop(current_token, inference_params):
         streamer.end()
     if enable_timing:
         end.record()
-        if tensor_parallel > 1:
-            torch.distributed.barrier()
         torch.cuda.synchronize()
         print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
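
With the barriers removed from the timing block, a multi-rank caller (for example, one still sharding the model across GPUs) can reproduce the old synchronization outside decode. A minimal sketch, assuming torch.distributed has already been initialized by the caller; timed_decode is an illustrative wrapper name, not part of this file.

import torch
import torch.distributed as dist

def timed_decode(*args, **kwargs):
    # Synchronize ranks before starting the clock, as the removed code did.
    multi_rank = dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
    if multi_rank:
        dist.barrier()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = decode(*args, **kwargs)
    end.record()
    if multi_rank:
        dist.barrier()
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {start.elapsed_time(end):.0f}ms")
    return out
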
@@ -232,22 +226,6 @@ def generate(
         return output if return_dict_in_generate else output.sequences


-def allocate_inference_cache(
-    max_batch_size,
-    max_seqlen,
-    nheads,
-    headdim,
-    layers: Union[int, Sequence],
-    device,
-    dtype=torch.float16,
-):
-    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
-    if isinstance(layers, int):
-        layers = range(layers)
-    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
-
-
 @dataclass
 class DecodingCGCache:
     max_batch_size: int = 0
@@ -268,7 +246,6 @@ def update_graph_cache(
     seqlen_og,
     max_seqlen,
     decoding_seqlens=(1,),
-    tensor_parallel=1,
     dtype=None,
     n_warmups=2,
 ):
@@ -289,23 +266,8 @@ def update_graph_cache(
         gc.collect()
         cache.device, cache.dtype = device, dtype
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
-        if hasattr(model, "allocate_inference_cache"):
-            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
-        else:
-            headdim = getattr(
-                model.config,
-                "head_dim",
-                model.config.hidden_size // model.config.num_attention_heads,
-            )
-            inf_cache = allocate_inference_cache(
-                batch_size,
-                max_seqlen,
-                model.config.num_attention_heads // tensor_parallel,
-                headdim,
-                model.config.num_hidden_layers,
-                device,
-                dtype,
-            )
+        assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
+        inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
             max_seqlen=max_seqlen,
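
Since the module-level fallback allocator is gone, any model used with cg=True must now provide allocate_inference_cache itself. Below is a minimal sketch of such a method for an attention-style model, mirroring the per-layer (batch, seqlen, 2, nheads, headdim) KV-cache layout of the removed helper; the MyAttentionModel class and its config attribute names are assumptions for illustration.

import torch

class MyAttentionModel(torch.nn.Module):  # hypothetical model, not part of this repo
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        # Mirror the removed module-level helper: one empty KV-cache tensor per layer,
        # shaped (batch, seqlen, 2, nheads, headdim), keyed by layer index.
        dtype = dtype if dtype is not None else torch.float16
        assert dtype in [torch.float16, torch.bfloat16, torch.float32]
        headdim = getattr(
            self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads
        )
        kv_cache_shape = (batch_size, max_seqlen, 2, self.config.num_attention_heads, headdim)
        device = next(self.parameters()).device
        return {
            i: torch.empty(kv_cache_shape, device=device, dtype=dtype)
            for i in range(self.config.num_hidden_layers)
        }
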