Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit f4133f0

Browse files
Allow reload after converter.save() to avoid engine rebuilding
1 parent f9f566b commit f4133f0

File tree

1 file changed

+37
-18
lines changed

1 file changed

+37
-18
lines changed

tftrt/examples/benchmark_runner.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222

2323
from tensorflow.python.compiler.tensorrt import trt_convert as trt
2424

25+
from tensorflow.python.saved_model import signature_constants
26+
from tensorflow.python.saved_model import tag_constants
27+
2528
__all__ = ["BaseBenchmarkRunner"]
2629

2730

@@ -100,28 +103,37 @@ def _get_graph_func(self):
100103
returns: TF function that is ready to run for inference
101104
"""
102105

103-
if not self._args.use_tftrt:
106+
def load_model_from_disk(
107+
path,
108+
tags=[tag_constants.SERVING],
109+
signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
110+
):
111+
saved_model_loaded = tf.saved_model.load(export_dir=path, tags=tags)
104112

105-
with timed_section("Loading TensorFlow native model"):
113+
graph_func = saved_model_loaded.signatures[signature_key]
106114

107-
saved_model_loaded = tf.saved_model.load(
108-
export_dir=self._args.input_saved_model_dir,
109-
tags=self._args.model_tag.split(",")
110-
)
115+
# from tensorflow.python.framework import convert_to_constants
116+
# graph_func = convert_to_constants.convert_variables_to_constants_v2(
117+
# graph_func
118+
# )
119+
120+
# Known TF Issue: https://github.com/tensorflow/tensorflow/issues/37615#issuecomment-767804930
121+
# It looks like the original trackable object is released by
122+
# the Python garbage collector once it goes out of scope, and
123+
# the signature returned by the function does not maintain a
124+
# back-reference to the original loaded object.
125+
graph_func._backref_to_saved_model = saved_model_loaded
111126

112-
graph_func = saved_model_loaded.signatures[
113-
self._args.input_signature_key]
114-
# from tensorflow.python.framework import convert_to_constants
115-
# graph_func = convert_to_constants.convert_variables_to_constants_v2(
116-
# graph_func
117-
# )
127+
return graph_func
118128

119-
# Known TF Issue: https://github.com/tensorflow/tensorflow/issues/37615#issuecomment-767804930
120-
# It looks like the original trackable object is released by
121-
# the Python garbage collector once it goes out of scope, and
122-
# the signature returned by the function does not maintain a
123-
# back-reference to the original loaded object.
124-
graph_func._backref_to_saved_model = saved_model_loaded
129+
if not self._args.use_tftrt:
130+
131+
with timed_section("Loading TensorFlow native model"):
132+
graph_func = load_model_from_disk(
133+
path=self._args.input_saved_model_dir,
134+
tags=self._args.model_tag.split(","),
135+
signature_key=self._args.input_signature_key
136+
)
125137

126138
else:
127139

@@ -231,6 +243,13 @@ def engine_build_input_fn(num_batches, model_phase):
231243
f"Converted graph saved to "
232244
f"`{self._args.output_saved_model_dir}`"
233245
)
246+
# The engine cache is cleared while saving, so we have to reload.
247+
# Failing to do so would force TF-TRT to rebuild the engines.
248+
del converter
249+
del graph_func
250+
graph_func = load_model_from_disk(
251+
self._args.output_saved_model_dir
252+
)
234253

235254
if isinstance(graph_func.structured_outputs, (tuple, list)):
236255
savedmodel_outputs = "\n - ".join([

0 commit comments

Comments
 (0)