import getpass
import logging
+
+from packaging import version
from six.moves.queue import Empty
from . import marker

@@ -61,8 +63,86 @@ def hdfs_path(ctx, path):


def start_cluster_server(ctx, num_gpus=1, rdma=False):
-  """*DEPRECATED*. Use higher-level APIs like `tf.keras` or `tf.estimator`"""
-  raise Exception("DEPRECATED: Use higher-level APIs like `tf.keras` or `tf.estimator`")
+  """Function that wraps the creation of a TensorFlow ``tf.train.Server`` for a node in a distributed TensorFlow cluster.
+
+  This is intended to be invoked from within the TF ``map_fun``, replacing explicit code to instantiate ``tf.train.ClusterSpec``
+  and ``tf.train.Server`` objects.
+
+  DEPRECATED for TensorFlow 2.x+.
+
+  Args:
+    :ctx: TFNodeContext containing the metadata specific to this node in the cluster.
+    :num_gpus: number of GPUs desired.
+    :rdma: boolean indicating if RDMA ('grpc+verbs') should be used for cluster communications.
+
+  Returns:
+    A tuple of (cluster_spec, server)
+  """
+  import os
+  import tensorflow as tf
+  import time
+  from . import gpu_info
+
+  if version.parse(tf.__version__) >= version.parse("2.0.0"):
+    raise Exception("DEPRECATED: Use higher-level APIs like `tf.keras` or `tf.estimator`")
+
+  logging.info("{0}: ======== {1}:{2} ========".format(ctx.worker_num, ctx.job_name, ctx.task_index))
+  cluster_spec = ctx.cluster_spec
+  logging.info("{0}: Cluster spec: {1}".format(ctx.worker_num, cluster_spec))
+
+  if tf.test.is_built_with_cuda() and num_gpus > 0:
+    # compute my index relative to other nodes placed on the same host (for GPU allocation)
+    my_addr = cluster_spec[ctx.job_name][ctx.task_index]
+    my_host = my_addr.split(':')[0]
+    flattened = [v for sublist in cluster_spec.values() for v in sublist]
+    local_peers = [p for p in flattened if p.startswith(my_host)]
+    my_index = local_peers.index(my_addr)
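+    # Illustrative example (hypothetical addresses): if cluster_spec is
+    #   {'ps': ['hostA:2222'], 'worker': ['hostA:2223', 'hostB:2222']},
+    # then the worker at 'hostA:2223' computes local_peers == ['hostA:2222', 'hostA:2223']
+    # and my_index == 1, which is passed to gpu_info.get_gpus below so that peers
+    # co-located on hostA request different GPUs.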
+
+    # GPU
+    gpu_initialized = False
+    retries = 3
+    while not gpu_initialized and retries > 0:
+      try:
+        # override PS jobs to not reserve a GPU
+        if ctx.job_name == 'ps':
+          num_gpus = 0
+
+        # Find free GPU(s) to use
+        gpus_to_use = gpu_info.get_gpus(num_gpus, my_index)
+        gpu_prompt = "GPU" if num_gpus == 1 else "GPUs"
+        logging.info("{0}: Using {1}: {2}".format(ctx.worker_num, gpu_prompt, gpus_to_use))
+
+        # Set GPU device to use for TensorFlow
+        os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use
+
+        # Create a cluster from the parameter server and worker hosts.
+        cluster = tf.train.ClusterSpec(cluster_spec)
+
+        # Create and start a server for the local task.
+        if rdma:
+          server = tf.train.Server(cluster, ctx.job_name, ctx.task_index, protocol="grpc+verbs")
+        else:
+          server = tf.train.Server(cluster, ctx.job_name, ctx.task_index)
+        gpu_initialized = True
+      except Exception as e:
+        print(e)
+        logging.error("{0}: Failed to allocate GPU, trying again...".format(ctx.worker_num))
+        retries -= 1
+        time.sleep(10)
+    if not gpu_initialized:
+      raise Exception("Failed to allocate GPU")
+  else:
+    # CPU
+    os.environ['CUDA_VISIBLE_DEVICES'] = ''
+    logging.info("{0}: Using CPU".format(ctx.worker_num))
+
+    # Create a cluster from the parameter server and worker hosts.
+    cluster = tf.train.ClusterSpec(cluster_spec)
+
+    # Create and start a server for the local task.
+    server = tf.train.Server(cluster, ctx.job_name, ctx.task_index)
+
+  return (cluster, server)
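For context, a minimal usage sketch (TensorFlow 1.x) of how `start_cluster_server` is typically called from a user-supplied `map_fun`. The `map_fun` body, the `tensorflowonspark` import path, and the `replica_device_setter` placement are illustrative assumptions, not part of this commit:

import tensorflow as tf
from tensorflowonspark import TFNode

def map_fun(args, ctx):
  # Reserve GPU(s) if available and start a tf.train.Server for this node.
  cluster, server = TFNode.start_cluster_server(ctx, num_gpus=1, rdma=False)

  if ctx.job_name == 'ps':
    # Parameter servers block here, serving variables to the workers.
    server.join()
  elif ctx.job_name == 'worker':
    # Workers place variables via the returned cluster spec and train as usual.
    with tf.device(tf.train.replica_device_setter(
        worker_device='/job:worker/task:{}'.format(ctx.task_index),
        cluster=cluster)):
      pass  # build the model graph and training loop here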


def next_batch(mgr, batch_size, qname='input'):
@@ -71,8 +151,55 @@ def next_batch(mgr, batch_size, qname='input'):


def export_saved_model(sess, export_dir, tag_set, signatures):
-  """*DEPRECATED*. Use TF provided APIs instead."""
-  raise Exception("DEPRECATED: Use TF provided APIs instead.")
+  """Convenience function to export a saved_model using provided arguments
+
+  The caller specifies the saved_model signatures in a simplified Python dictionary form, as follows::
+
+    signatures = {
+      'signature_def_key': {
+        'inputs': { 'input_tensor_alias': input_tensor },
+        'outputs': { 'output_tensor_alias': output_tensor },
+        'method_name': 'method'
+      }
+    }
+
+  This function will then generate the ``signature_def_map`` and export the saved_model.
+
+  DEPRECATED for TensorFlow 2.x+.
+
+  Args:
+    :sess: a tf.Session instance
+    :export_dir: path to save exported saved_model
+    :tag_set: string tag_set to identify the exported graph
+    :signatures: simplified dictionary representation of a TensorFlow signature_def_map
+
+  Returns:
+    A saved_model exported to disk at ``export_dir``.
+  """
+  import tensorflow as tf
+
+  if version.parse(tf.__version__) >= version.parse("2.0.0"):
+    raise Exception("DEPRECATED: Use TF provided APIs instead.")
+
+  g = sess.graph
+  g._unsafe_unfinalize()  # https://github.com/tensorflow/serving/issues/363
+  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
+
+  logging.info("===== signatures: {}".format(signatures))
+  signature_def_map = {}
+  for key, sig in signatures.items():
+    signature_def_map[key] = tf.saved_model.signature_def_utils.build_signature_def(
+      inputs={name: tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['inputs'].items()},
+      outputs={name: tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['outputs'].items()},
+      method_name=sig['method_name'] if 'method_name' in sig else key)
+  logging.info("===== signature_def_map: {}".format(signature_def_map))
+  builder.add_meta_graph_and_variables(
+    sess,
+    tag_set.split(','),
+    signature_def_map=signature_def_map,
+    clear_devices=True)
+  g.finalize()
+  builder.save()
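For context, a minimal sketch (TensorFlow 1.x) of building the simplified `signatures` dictionary described in the docstring and exporting a toy graph. The graph, tensor aliases, and export path are illustrative placeholders, not part of this commit:

import tensorflow as tf
from tensorflowonspark import TFNode

# Toy graph purely for illustration.
x = tf.placeholder(tf.float32, shape=[None, 4], name='x')
w = tf.Variable(tf.zeros([4, 2]), name='w')
y = tf.matmul(x, w, name='y')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  signatures = {
    'predict': {
      'inputs': {'features': x},
      'outputs': {'scores': y},
      'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
    }
  }
  # Exports a saved_model tagged 'serve' to the given directory.
  TFNode.export_saved_model(sess, '/tmp/exported_model', 'serve', signatures)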


def batch_results(mgr, results, qname='output'):