Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit ef0a03d

Browse files
committed
decoding start
1 parent 0f28976 commit ef0a03d

File tree

2 files changed: +70 -7 lines changed

dist_run.py

Lines changed: 69 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
get_hf_weight_map_and_path,
2222
load_safetensor_weights)
2323
from distributed.utils import Color as color
24-
from distributed.utils import (GPUMemoryMonitor, TrackTime,
24+
from distributed.utils import (GPUMemoryMonitor, TrackTime, CUDATrackTime,
2525
bytes_to_readable, get_module_size,
2626
get_num_params)
2727
from distributed.verification_utils import find_cpu_tensors
@@ -256,7 +256,7 @@ def main():
256256
model.distribute(tp_mesh)
257257
# logger.info(f"Model: {model}")
258258

259-
mbs = 4 # number of micro-batches
259+
mbs = 1 # number of micro-batches
260260
mb_size = 1 # micro-batch size
261261
batch_size = mbs * mb_size # total batch size
262262

@@ -308,11 +308,28 @@ def main():
308308
raise ValueError("Found cpu tensors in stage")
309309

310310
prompt = [
311-
"What is the capital of France?",
312311
"What is snow?",
312+
]
313+
314+
'''
315+
"What is the capital of France?",
313316
"What is your name?",
314317
"What is the capital of Japan?",
318+
"When is Christmas?",
319+
"Where does Santa Claus live?",
320+
"What is the capital of the United States?",
321+
"What is the capital of China?",
322+
"What is the capital of Russia?",
323+
"What is PyTorch?",
324+
"What is the capital of India?",
325+
"What is an LLM?",
326+
"What is the capital of Brazil?",
327+
"What is the capital of Mexico?",
328+
"What is the capital of Argentina?",
329+
"What is the capital of Canada?",
315330
]
331+
'''
332+
316333
start_pos = 0
317334

318335
# encode the prompt
@@ -333,12 +350,57 @@ def main():
333350
schedule = ScheduleGPipe(stage, mbs)
334351
logger.info(f"Created schedule: {schedule}")
335352

353+
# with CUDATrackTime() as timer:
354+
first_pp_rank = 0
355+
last_pp_rank = pp_degree - 1
356+
357+
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,], device=device)
358+
x_recv = torch.zeros_like(x)
359+
360+
last_global_rank = world_size - 1
361+
336362
with torch.no_grad(): # .inference_mode():
363+
# for _ in range(1):
364+
# first
337365
if pp_rank == 0:
338366
schedule.step(padded_sequence)
339-
else:
367+
if rank == 0:
368+
dist.recv(padded_sequence, src=last_global_rank, )
369+
logger.info(f"RECEIVING from {last_global_rank=}")
370+
# elif rank == 1:
371+
# dist.recv(x_recv, src=last_global_rank-1, )
372+
# logger.info(f"RECEIVING from {last_global_rank=}")
373+
# logger.info(f"Received x_recv: {x_recv=}")
374+
375+
# elif tp_rank == 1:
376+
# dist.recv(padded_sequence, src=last_global_rank-1, )
377+
# last
378+
elif pp_rank == last_pp_rank:
340379
output = schedule.step()
341-
380+
logger.info(f"SENDING back...from {pp_rank=}")
381+
#if tp_rank == 0:
382+
if rank == world_size-1:
383+
dist.send(output, dst=0, )
384+
#dist.send(x, dst=1, )
385+
386+
# dist.send(x,dst = 1,)
387+
#elif tp_rank==1:
388+
# dist.send(output, dst=1, )
389+
# elif tp_rank == 1:
390+
# dist.send(output, dst=1, )
391+
# middle pp ranks
392+
else:
393+
schedule.step()
394+
395+
if rank==0:
396+
logger.info(f"{color.red} Success! Received output from {last_global_rank} {color.reset}")
397+
logger.info(f"out of loop - Received output: {padded_sequence[4:8]=}") # {padded_sequence[0, :prompt_lengths[0]+1]=}")
398+
if rank ==1:
399+
logger.info(f"{color.red} Success! Received output from {last_global_rank} {color.reset}")
400+
logger.info(f"out of loop Received output: {x_recv=}") # {padded_sequence[0, :prompt_lengths[0]+1]=}")
401+
402+
#logger.info(f"{color.green}Total prefill time: {timer.get_time()} {timer.unit}{color.reset}")
403+
'''
342404
# Decoding
343405
if pp_rank == pp_degree - 1 and tp_rank == 0:
344406
decode_results = _batch_decode_next_tokens(
@@ -358,7 +420,8 @@ def main():
358420
logger.info(
359421
f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
360422
)
361-
423+
'''
424+
logger.info(f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}")
362425
_cleanup()
363426

364427

run_dist.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
export CUDA_VISIBLE_DEVICES=4,5,6,7
1+
#export CUDA_VISIBLE_DEVICES=4,5,6,7
22
PORT=${1:-29501}
33
NGPU=${NGPU:-"4"}
44
LOG_RANK=${LOG_RANK:-0,1,2,3}

0 commit comments

Comments (0)