
Commit dcc9309

Remove tensor parallel in generation
1 parent cc01e2b commit dcc9309

File tree

1 file changed (+2, -40 lines)


mamba_ssm/utils/generation.py

Lines changed: 2 additions & 40 deletions
@@ -100,7 +100,6 @@ def decode(
     eos_token_id=None,
     teacher_outputs=None,
     vocab_size=None,
-    tensor_parallel=1,
     cg=False,
     enable_timing=False,
     streamer: Optional[TextStreamer] = None
@@ -134,7 +133,6 @@ def decode(
             batch_size,
             seqlen_og,
             max_length,
-            tensor_parallel=tensor_parallel,
         )
         inference_params = model._decoding_cache.inference_params
         inference_params.reset(max_length, batch_size)
@@ -186,8 +184,6 @@ def should_stop(current_token, inference_params):
     end = torch.cuda.Event(enable_timing=enable_timing)
 
     if enable_timing:
-        if tensor_parallel > 1:
-            torch.distributed.barrier()
         start.record()
     scores, sequences = [], [input_ids]
     while not should_stop(sequences[-1], inference_params):
@@ -201,8 +197,6 @@ def should_stop(current_token, inference_params):
         streamer.end()
     if enable_timing:
         end.record()
-        if tensor_parallel > 1:
-            torch.distributed.barrier()
         torch.cuda.synchronize()
         print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
@@ -232,22 +226,6 @@ def generate(
     return output if return_dict_in_generate else output.sequences
 
 
-def allocate_inference_cache(
-    max_batch_size,
-    max_seqlen,
-    nheads,
-    headdim,
-    layers: Union[int, Sequence],
-    device,
-    dtype=torch.float16,
-):
-    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
-    if isinstance(layers, int):
-        layers = range(layers)
-    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
-
-
 @dataclass
 class DecodingCGCache:
     max_batch_size: int = 0
@@ -268,7 +246,6 @@ def update_graph_cache(
     seqlen_og,
     max_seqlen,
     decoding_seqlens=(1,),
-    tensor_parallel=1,
     dtype=None,
     n_warmups=2,
 ):
@@ -289,23 +266,8 @@ def update_graph_cache(
         gc.collect()
         cache.device, cache.dtype = device, dtype
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
-        if hasattr(model, "allocate_inference_cache"):
-            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
-        else:
-            headdim = getattr(
-                model.config,
-                "head_dim",
-                model.config.hidden_size // model.config.num_attention_heads,
-            )
-            inf_cache = allocate_inference_cache(
-                batch_size,
-                max_seqlen,
-                model.config.num_attention_heads // tensor_parallel,
-                headdim,
-                model.config.num_hidden_layers,
-                device,
-                dtype,
-            )
+        assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
+        inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
             max_seqlen=max_seqlen,

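With the generic attention-style allocate_inference_cache fallback removed, CUDA-graph decoding (cg=True) now hard-requires the model to expose its own allocate_inference_cache method, as enforced by the new assert in update_graph_cache. The sketch below is illustrative only: the class name, dimensions, and cache shapes are hypothetical placeholders rather than the actual Mamba state layout; only the method name and the (batch_size, max_seqlen, dtype) call pattern come from this diff.

import torch

class ToyDecodingModel(torch.nn.Module):
    """Hypothetical model showing the contract checked by the new assert."""

    def __init__(self, n_layers=2, d_model=64, d_state=16):
        super().__init__()
        self.n_layers, self.d_model, self.d_state = n_layers, d_model, d_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        # Called by update_graph_cache as
        # model.allocate_inference_cache(batch_size, max_seqlen, dtype).
        # Return one cache entry per layer; the shapes here are placeholders,
        # real Mamba blocks allocate their own conv/SSM decoding state.
        dtype = dtype if dtype is not None else torch.float16
        return {
            i: torch.empty(batch_size, self.d_model, self.d_state,
                           device="cuda", dtype=dtype)
            for i in range(self.n_layers)
        }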