@@ -39,7 +39,7 @@ def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
 def _launch_distributed_inference(
     model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs
 ) -> tuple[List]:
-    # create programmatic elastic launch
+    # launch distributed inference workers; each worker gets a pipe to communicate with the main process
     logger.info("Launching distributed inference ...")

     num_processes_per_node = builder_args.pp * builder_args.tp
@@ -50,17 +50,25 @@ def _launch_distributed_inference(

     pipes = []
     procs = []
-    for rank in range(num_processes_per_node):
-        server_pipe, client_pipe = mp.Pipe(duplex=True)
-        pipes.append(server_pipe)
-        proc = mp.Process(
-            target=partial(_setup_env, num_processes_per_node, rank, main),
-            args=(model_name, builder_args, tokenizer_args, client_pipe),
-        )
-        proc.start()
+    try:
+        for rank in range(num_processes_per_node):
+            server_pipe, client_pipe = mp.Pipe(duplex=True)
+            pipes.append(server_pipe)
+            procs.append(
+                mp.Process(
+                    target=partial(_setup_env, num_processes_per_node, rank, main),
+                    args=(model_name, builder_args, tokenizer_args, client_pipe),
+                )
+            )
+            procs[-1].start()

-    for pipe in pipes:
-        response = pipe.recv()
+        for pipe in pipes:
+            assert pipe.recv() == "ready", "Starting the worker failed"
+    except Exception as e:
+        logger.error(f"Error during distributed inference: {str(e)}")
+        for p in procs:
+            p.kill()
+        raise e

     logger.info(
         f"Done launching distributed inference on {num_processes_per_node} GPUs."
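
For context, the launch-and-handshake pattern this hunk introduces looks roughly like the following in isolation (a minimal sketch; `_worker` stands in for the real `_setup_env`/`main` pair, and the worker count is arbitrary):

```python
import multiprocessing as mp

def _worker(rank: int, pipe) -> None:
    # ... set up env vars and load the model shard here ...
    pipe.send("ready")  # unblocks the parent's recv() below

if __name__ == "__main__":
    pipes, procs = [], []
    try:
        for rank in range(4):
            server_pipe, client_pipe = mp.Pipe(duplex=True)
            pipes.append(server_pipe)
            procs.append(mp.Process(target=_worker, args=(rank, client_pipe)))
            procs[-1].start()
        for pipe in pipes:
            # blocks until the worker checks in; raises EOFError if it died first
            assert pipe.recv() == "ready"
    except Exception:
        for p in procs:
            p.kill()
        raise
```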
@@ -105,11 +113,13 @@ def __init__(
         self.loop = loop

     def schedule_request(self, req: Request):
+        # add the request to the queue and create a deque and an async event for its response
         self.req_to_states[req.request_id] = asyncio.Event()
         self.req_to_results[req.request_id] = deque()
         self.request_queue.put(req)

     def process_requests_loop(self):
+        # Continuously process requests (one at a time for now); results are routed into the request's deque
         while True:
             req = self.request_queue.get()
             if req == "stop":
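
The hand-off between this scheduler thread and the async side presumably uses the standard deque-plus-Event pattern: the thread appends a result, then sets the event on the loop's own thread. A self-contained sketch (all names here are illustrative, not taken from the diff):

```python
import asyncio
from collections import deque

async def main() -> None:
    loop = asyncio.get_running_loop()
    results: deque = deque()
    ready = asyncio.Event()

    def producer() -> None:
        # runs on a plain thread, like process_requests_loop
        results.append("some output")
        # asyncio.Event is not thread-safe: set it from the loop's thread
        loop.call_soon_threadsafe(ready.set)

    loop.run_in_executor(None, producer)
    await ready.wait()
    print(results.popleft())  # -> "some output"

asyncio.run(main())
```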
@@ -127,6 +137,7 @@ def process_requests_loop(self):
                 running &= not outputs[0].is_finished

     async def wait_for_request(self, req: Request) -> Output:
+        # Wait for the request to deliver its result; the event triggers reads from the left side of the deque
         is_finished = False
         while not is_finished:
             await self.req_to_states[req.request_id].wait()
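
On the consumer side, the loop described by this hunk's comment would wake on the event, clear it, and drain the deque from the left; roughly (a hypothetical shape for the lines not shown in this diff):

```python
import asyncio
from collections import deque

async def drain(ready: asyncio.Event, results: deque) -> list:
    collected = []
    await ready.wait()
    # clear first so the next wait() blocks until the producer signals again
    ready.clear()
    while results:
        collected.append(results.popleft())
    return collected
```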
@@ -138,6 +149,7 @@ async def wait_for_request(self, req: Request) -> Output:
         del self.req_to_results[req.request_id]

     def step(self) -> List[Output]:
+        # Run one prefill or decoding step and collect the results
         responses = []
         # TODO: Implement a scheduler to handle the requests
         if len(self.in_flight_requests) > 0:
@@ -166,6 +178,7 @@ def step(self) -> List[Output]:
             text, token_ids = v
             outputs.append(
                 Output(
+                    # TODO: Look for tokenizer.eos_id as well
                     is_finished=self.current_step >= self.generator_args.max_new_tokens,
                     text=text,
                     token=token_ids,
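
The TODO presumably amounts to adding an EOS check to the stop condition; one possible shape (a hypothetical helper, assuming the tokenizer exposes `eos_id` and that `token_ids` holds the tokens generated so far):

```python
def _is_finished(current_step: int, max_new_tokens: int,
                 token_ids: list, eos_id: int) -> bool:
    # stop on the token budget, or, per the TODO, on an EOS token
    return current_step >= max_new_tokens or (
        len(token_ids) > 0 and token_ids[-1] == eos_id
    )
```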
@@ -218,6 +231,7 @@ def __init__(
         atexit.register(self.shutdown)

     def shutdown(self):
+        # Stop all processes and threads
         self.scheduler.request_queue.put("stop")
         self.scheduler_thread.join()
@@ -227,6 +241,7 @@ def shutdown(self):
             p.kill()

     def generate(self, text):
+        # Generate text from a prompt
         req = Request.new_request(text)
         self.scheduler.schedule_request(req)
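
End to end, a caller would presumably schedule a request and then await its output on the scheduler's loop; a hypothetical usage sketch (whether `generate` itself awaits the result is not shown in this diff):

```python
async def run(generator) -> None:
    req = Request.new_request("What is snow?")
    generator.scheduler.schedule_request(req)
    output = await generator.scheduler.wait_for_request(req)
    print(output.text)
```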