99RETRY_INTERVAL_SEC = 5
1010
1111
def get_node_fqdn(leader_addr: str) -> str:
    """Build an externally addressable DNS name for the current node.

    Assumes we're on a K8s cluster where the leader address has the form
    ``<leader pod name>.<the rest of the FQDN>`` (e.g. when using a JobSet).
    The local hostname replaces the leader's first DNS label.

    Args:
        leader_addr: DNS name of the leader; must contain at least one ".".

    Returns:
        The local hostname joined by "." to everything after the first "."
        in ``leader_addr``.

    Raises:
        ValueError: If ``leader_addr`` contains no "." and therefore has no
            domain suffix to reuse.
    """
    _, dot, domain_suffix = leader_addr.partition(".")
    if not dot:
        # Without a suffix there is nothing to graft the hostname onto; fail
        # with a clear message instead of an opaque IndexError.
        raise ValueError(
            f"Cannot derive node FQDN: leader address {leader_addr!r} has no domain suffix"
        )
    # Kind of a hack to get an externally addressable DNS name.
    return socket.gethostname() + "." + domain_suffix
1919
2020
2121def wait_for_dns (hostname : str , timeout : int = 300 , interval : int = 5 ):
@@ -48,7 +48,7 @@ def wait_for_cluster_nodes(
4848 bool: True if cluster reached expected size, False if timeout occurred
4949 """
5050 # Ray was started via subprocess.run, so this process must connect to the cluster right here.
51- ray .init ()
51+ ray .init (log_to_driver = False )
5252 start_time = time .time ()
5353 while time .time () - start_time < timeout :
5454 try :
@@ -114,7 +114,7 @@ def wait_for_head_node_to_exit_process():
114114 # This will run in the subprocess spawned and will conveniently error out
115115 # when the head node is no longer reachable
116116 # The exit gets caught by the `wait_for_head_node_to_exit` function
117- ray .init ()
117+ ray .init (log_to_driver = False )
118118 while True :
119119 nodes = ray .nodes ()
120120 print (f"Able to get nodes list { len (nodes )} " , flush = True )
@@ -137,6 +137,20 @@ def start_leader(
137137 return False
138138
139139
def is_ipv6_address(ip_address: str) -> bool:
    """Return True if ``ip_address`` parses as an IPv6 address literal.

    Args:
        ip_address: Candidate address string (an IPv4 literal, IPv6 literal,
            hostname, or anything else).

    Returns:
        True when ``socket.inet_pton`` accepts the string as ``AF_INET6``;
        False otherwise.
    """
    try:
        socket.inet_pton(socket.AF_INET6, ip_address)
    except (OSError, ValueError):
        # OSError: the string is not a valid IPv6 literal (socket.error is
        # only a deprecated alias of OSError).
        # ValueError: the string cannot be passed to the C parser at all
        # (e.g. it contains an embedded null character) - still "not IPv6".
        return False
    return True
146+
147+
def format_ip_address(ip_address: str) -> str:
    """Format an address for use in a ``host:port`` string.

    IPv6 literals are wrapped in square brackets so the port separator is
    unambiguous; any other value is returned unchanged.
    """
    return f"[{ip_address}]" if is_ipv6_address(ip_address) else ip_address
152+
153+
140154def start_worker (
141155 ray_port : int ,
142156 node_ip_address : str ,
@@ -147,17 +161,22 @@ def start_worker(
147161 # node ip address in this case is actually a DNS name for the pod
148162 start_time = time .time ()
149163 while time .time () - start_time < timeout :
164+ print (
165+ f"Starting ray worker with head address { format_ip_address (leader_addr )} :{ ray_port } and node ip address { node_ip_address } " ,
166+ flush = True ,
167+ )
150168 result = subprocess .run (
151169 [
152170 "ray" ,
153171 "start" ,
154172 "--address" ,
155- f"{ leader_addr } :{ ray_port } " ,
173+ f"{ format_ip_address ( leader_addr ) } :{ ray_port } " ,
156174 "--node-ip-address" ,
157175 node_ip_address ,
158176 ],
159177 capture_output = True ,
160178 )
179+ print (f"result: { result } " , flush = True )
161180 if result .returncode == 0 :
162181 print (
163182 f"Worker: Ray runtime started with head address { leader_addr } :{ ray_port } " ,
@@ -175,6 +194,13 @@ def start_worker(
175194 return False
176195
177196
def get_node_ip_address(node_fqdn: str, timeout: int = 300) -> str:
    """Resolve ``node_fqdn`` via ``wait_for_dns`` and return one address.

    Args:
        node_fqdn: DNS name to resolve.
        timeout: Passed through to ``wait_for_dns`` as its ``timeout``.

    Returns:
        The element at ``[0][4][0]`` of the ``wait_for_dns`` result.
        NOTE(review): this indexing matches the shape of
        ``socket.getaddrinfo`` results (first entry -> sockaddr -> host);
        confirm against ``wait_for_dns``.

    Raises:
        RuntimeError: If ``wait_for_dns`` returns None.
    """
    resolved = wait_for_dns(node_fqdn, timeout=timeout)
    if resolved is not None:
        first_entry = resolved[0]
        sockaddr = first_entry[4]
        return sockaddr[0]
    raise RuntimeError(f"Timeout waiting for DNS resolution of {node_fqdn}")
202+
203+
178204def init_ray (
179205 leader_addr : str ,
180206 leader_port : int ,
@@ -193,21 +219,30 @@ def init_ray(
193219 node_ip_address: IP address of the current node. If None, will be automatically detected
194220 timeout: Maximum time to wait for cluster to reach expected size
195221 """
196- node_ip_address = get_node_ip_address (leader_addr )
222+ import os
223+
224+ # Set NCCL_DEBUG=INFO to enable verbose NCCL logging (this does not disable Ray logging)
225+ os .environ ["NCCL_DEBUG" ] = "INFO"
226+
227+ # Get FQDN of the current node
228+ node_fqdn = get_node_fqdn (leader_addr )
229+ print (f"node fqdn: { node_fqdn } " , flush = True )
197230
198231 print (f"Waiting for head node DNS ({ leader_addr } ) to be resolvable..." , flush = True )
199- head_ip_info = wait_for_dns (leader_addr , timeout = timeout )
200- if head_ip_info is None :
201- raise RuntimeError (f"Timeout waiting for DNS resolution of { leader_addr } " )
232+ leader_ip_address = get_node_ip_address (leader_addr , timeout = timeout )
233+ print (f"leader ip: { leader_ip_address } " , flush = True )
202234
203235 if is_leader :
204- if not start_leader (leader_port , node_ip_address ):
236+ if not start_leader (leader_port , leader_ip_address ):
205237 raise RuntimeError ("Failed to start Ray leader node" )
206238 else :
207- if not start_worker (leader_port , node_ip_address , leader_addr , timeout ):
239+ print (f"Waiting for worker node DNS ({ node_fqdn } ) to be resolvable..." , flush = True )
240+ worker_ip_address = get_node_ip_address (node_fqdn , timeout = timeout )
241+ print (f"worker ip: { worker_ip_address } " , flush = True )
242+ if not start_worker (leader_port , worker_ip_address , leader_ip_address , timeout ):
208243 raise RuntimeError ("Failed to start Ray worker node" )
209244 print (
210- f"Successfully initialized Ray { 'head' if is_leader else 'worker' } node at { node_ip_address } " ,
245+ f"Successfully initialized Ray { 'head' if is_leader else 'worker' } node at { leader_ip_address if is_leader else worker_ip_address } " ,
211246 flush = True ,
212247 )
213248
@@ -231,5 +266,6 @@ def main(mode: str):
231266if __name__ == "__main__" :
232267 parser = argparse .ArgumentParser ()
233268 parser .add_argument ("--mode" , choices = ["wait_for_head_node_to_exit" ], required = True )
269+
234270 args = parser .parse_args ()
235271 main (args .mode )
0 commit comments