2121 get_hf_weight_map_and_path ,
2222 load_safetensor_weights )
2323from distributed .utils import Color as color
24- from distributed .utils import (GPUMemoryMonitor , TrackTime ,
24+ from distributed .utils import (GPUMemoryMonitor , TrackTime , CUDATrackTime ,
2525 bytes_to_readable , get_module_size ,
2626 get_num_params )
2727from distributed .verification_utils import find_cpu_tensors
@@ -256,7 +256,7 @@ def main():
256256 model .distribute (tp_mesh )
257257 # logger.info(f"Model: {model}")
258258
259- mbs = 4 # number of micro-batches
259+ mbs = 1 # number of micro-batches
260260 mb_size = 1 # micro-batch size
261261 batch_size = mbs * mb_size # total batch size
262262
@@ -308,11 +308,28 @@ def main():
308308 raise ValueError ("Found cpu tensors in stage" )
309309
310310 prompt = [
311- "What is the capital of France?" ,
312311 "What is snow?" ,
312+ ]
313+
314+ '''
315+ "What is the capital of France?",
313316 "What is your name?",
314317 "What is the capital of Japan?",
318+ "When is Christmas?",
319+ "Where does Santa Claus live?",
320+ "What is the capital of the United States?",
321+ "What is the capital of China?",
322+ "What is the capital of Russia?",
323+ "What is PyTorch?",
324+ "What is the capital of India?",
325+ "What is an LLM?",
326+ "What is the capital of Brazil?",
327+ "What is the capital of Mexico?",
328+ "What is the capital of Argentina?",
329+ "What is the capital of Canada?",
315330 ]
331+ '''
332+
316333 start_pos = 0
317334
318335 # encode the prompt
@@ -333,12 +350,57 @@ def main():
333350 schedule = ScheduleGPipe (stage , mbs )
334351 logger .info (f"Created schedule: { schedule } " )
335352
353+ # with CUDATrackTime() as timer:
354+ first_pp_rank = 0
355+ last_pp_rank = pp_degree - 1
356+
357+ x = torch .tensor ([1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,], device = device )
358+ x_recv = torch .zeros_like (x )
359+
360+ last_global_rank = world_size - 1
361+
336362 with torch .no_grad (): # .inference_mode():
363+ # for _ in range(1):
364+ # first
337365 if pp_rank == 0 :
338366 schedule .step (padded_sequence )
339- else :
367+ if rank == 0 :
368+ dist .recv (padded_sequence , src = last_global_rank , )
369+ logger .info (f"RECEIVING from { last_global_rank = } " )
370+ # elif rank == 1:
371+ # dist.recv(x_recv, src=last_global_rank-1, )
372+ # logger.info(f"RECEIVING from {last_global_rank=}")
373+ # logger.info(f"Received x_recv: {x_recv=}")
374+
375+ # elif tp_rank == 1:
376+ # dist.recv(padded_sequence, src=last_global_rank-1, )
377+ # last
378+ elif pp_rank == last_pp_rank :
340379 output = schedule .step ()
341-
380+ logger .info (f"SENDING back...from { pp_rank = } " )
381+ #if tp_rank == 0:
382+ if rank == world_size - 1 :
383+ dist .send (output , dst = 0 , )
384+ #dist.send(x, dst=1, )
385+
386+ # dist.send(x,dst = 1,)
387+ #elif tp_rank==1:
388+ # dist.send(output, dst=1, )
389+ # elif tp_rank == 1:
390+ # dist.send(output, dst=1, )
391+ # middle pp ranks
392+ else :
393+ schedule .step ()
394+
395+ if rank == 0 :
396+ logger .info (f"{ color .red } Success! Received output from { last_global_rank } { color .reset } " )
397+ logger .info (f"out of loop - Received output: { padded_sequence [4 :8 ]= } " ) # {padded_sequence[0, :prompt_lengths[0]+1]=}")
398+ if rank == 1 :
399+ logger .info (f"{ color .red } Success! Received output from { last_global_rank } { color .reset } " )
400+ logger .info (f"out of loop Received output: { x_recv = } " ) # {padded_sequence[0, :prompt_lengths[0]+1]=}")
401+
402+ #logger.info(f"{color.green}Total prefill time: {timer.get_time()} {timer.unit}{color.reset}")
403+ '''
342404 # Decoding
343405 if pp_rank == pp_degree - 1 and tp_rank == 0:
344406 decode_results = _batch_decode_next_tokens(
@@ -358,7 +420,8 @@ def main():
358420 logger.info(
359421 f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
360422 )
361-
423+ '''
424+ logger .info (f"{ color .green } Success{ color .white } - { color .blue } Rank { rank } has completed.{ color .reset } " )
362425 _cleanup ()
363426
364427
0 commit comments