2323import argparse
2424import copy
2525import logging
26+ import pkg_resources
2627import sys
27- import tensorflow as tf
2828
2929from . import TFCluster , util
3030from packaging import version
3131
3232
3333logger = logging .getLogger (__name__ )
34+ TF_VERSION = pkg_resources .get_distribution ('tensorflow' ).version
3435
3536
3637# TensorFlowOnSpark Params
@@ -370,7 +371,7 @@ def __init__(self, train_fn, tf_args, export_fn=None):
370371 self .train_fn = train_fn
371372 self .args = Namespace (tf_args )
372373
373- master_node = 'chief' if version .parse (tf . __version__ ) >= version .parse ("2.0.0" ) else None
374+ master_node = 'chief' if version .parse (TF_VERSION ) >= version .parse ("2.0.0" ) else None
374375 self ._setDefault (input_mapping = {},
375376 cluster_size = 1 ,
376377 num_ps = 0 ,
@@ -413,7 +414,7 @@ def _fit(self, dataset):
413414 cluster .shutdown (grace_secs = self .getGraceSecs ())
414415
415416 if self .export_fn :
416- if version .parse (tf . __version__ ) < version .parse ("2.0.0" ):
417+ if version .parse (TF_VERSION ) < version .parse ("2.0.0" ):
417418 # For TF1.x, run export function, if provided
418419 assert local_args .export_dir , "Export function requires --export_dir to be set"
419420 logging .info ("Exporting saved_model (via export_fn) to: {}" .format (local_args .export_dir ))
@@ -480,7 +481,7 @@ def _transform(self, dataset):
480481
481482 tf_args = self .args .argv if self .args .argv else local_args
482483
483- _run_model = _run_model_tf1 if version .parse (tf . __version__ ) < version .parse ("2.0.0" ) else _run_model_tf2
484+ _run_model = _run_model_tf1 if version .parse (TF_VERSION ) < version .parse ("2.0.0" ) else _run_model_tf2
484485 rdd_out = dataset .select (input_cols ).rdd .mapPartitions (lambda it : _run_model (it , local_args , tf_args ))
485486
486487 # convert to a DataFrame-friendly format
@@ -516,7 +517,7 @@ def _run_model_tf1(iterator, args, tf_args):
516517 output_tensor_names = [tensor for tensor , col in sorted (args .output_mapping .items ())]
517518
518519 # if using a signature_def_key, get input/output tensor info from the requested signature
519- if version .parse (tf . __version__ ) < version .parse ("2.0.0" ) and args .signature_def_key :
520+ if version .parse (TF_VERSION ) < version .parse ("2.0.0" ) and args .signature_def_key :
520521 assert args .export_dir , "Inferencing with signature_def_key requires --export_dir argument"
521522 logging .info ("===== loading meta_graph_def for tag_set ({0}) from saved_model: {1}" .format (args .tag_set , args .export_dir ))
522523 meta_graph_def = get_meta_graph_def (args .export_dir , args .tag_set )
@@ -534,6 +535,7 @@ def _run_model_tf1(iterator, args, tf_args):
534535 sess = global_sess
535536 else :
536537 # otherwise, create new session and load graph from disk
538+ import tensorflow as tf
537539 tf .reset_default_graph ()
538540 sess = tf .Session (graph = tf .get_default_graph ())
539541 if args .export_dir :
@@ -584,6 +586,8 @@ def _run_model_tf2(iterator, args, tf_args):
584586 """mapPartitions function (for TF2.x) to run single-node inferencing from a saved_model, using input/output mappings."""
585587 single_node_env (tf_args )
586588
589+ import tensorflow as tf
590+
587591 logger .info ("===== input_mapping: {}" .format (args .input_mapping ))
588592 logger .info ("===== output_mapping: {}" .format (args .output_mapping ))
589593 input_tensor_names = [tensor for col , tensor in sorted (args .input_mapping .items ())]
0 commit comments