|
22 | 22 |
|
23 | 23 | from tensorflow.python.compiler.tensorrt import trt_convert as trt |
24 | 24 |
|
| 25 | +from tensorflow.python.saved_model import signature_constants |
| 26 | +from tensorflow.python.saved_model import tag_constants |
| 27 | + |
25 | 28 | __all__ = ["BaseBenchmarkRunner"] |
26 | 29 |
|
27 | 30 |
|
@@ -100,28 +103,37 @@ def _get_graph_func(self): |
100 | 103 | returns: TF function that is ready to run for inference |
101 | 104 | """ |
102 | 105 |
|
103 | | - if not self._args.use_tftrt: |
| 106 | + def load_model_from_disk( |
| 107 | + path, |
| 108 | + tags=[tag_constants.SERVING], |
| 109 | + signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY |
| 110 | + ): |
| 111 | + saved_model_loaded = tf.saved_model.load(export_dir=path, tags=tags) |
104 | 112 |
|
105 | | - with timed_section("Loading TensorFlow native model"): |
| 113 | + graph_func = saved_model_loaded.signatures[signature_key] |
106 | 114 |
|
107 | | - saved_model_loaded = tf.saved_model.load( |
108 | | - export_dir=self._args.input_saved_model_dir, |
109 | | - tags=self._args.model_tag.split(",") |
110 | | - ) |
| 115 | + # from tensorflow.python.framework import convert_to_constants |
| 116 | + # graph_func = convert_to_constants.convert_variables_to_constants_v2( |
| 117 | + # graph_func |
| 118 | + # ) |
| 119 | + |
| 120 | + # Known TF Issue: https://github.com/tensorflow/tensorflow/issues/37615#issuecomment-767804930
| 121 | + # It looks like the original trackable object is released by
| 122 | + # the Python garbage collector once it goes out of scope, because
| 123 | + # the signature returned by the function does not maintain a
| 124 | + # back-reference to the original loaded object.
| 125 | + graph_func._backref_to_saved_model = saved_model_loaded |
111 | 126 |
|
112 | | - graph_func = saved_model_loaded.signatures[ |
113 | | - self._args.input_signature_key] |
114 | | - # from tensorflow.python.framework import convert_to_constants |
115 | | - # graph_func = convert_to_constants.convert_variables_to_constants_v2( |
116 | | - # graph_func |
117 | | - # ) |
| 127 | + return graph_func |
118 | 128 |
|
119 | | - # Known TF Issue: https://github.com/tensorflow/tensorflow/issues/37615#issuecomment-767804930 |
120 | | - # it looks like if the original trackable object is released by |
121 | | - # the Python garbage collector once it goes out of scope, and |
122 | | - # the signature returned by the function does not maintain a |
123 | | - # back-reference to the original loaded object. |
124 | | - graph_func._backref_to_saved_model = saved_model_loaded |
| 129 | + if not self._args.use_tftrt: |
| 130 | + |
| 131 | + with timed_section("Loading TensorFlow native model"): |
| 132 | + graph_func = load_model_from_disk( |
| 133 | + path=self._args.input_saved_model_dir, |
| 134 | + tags=self._args.model_tag.split(","), |
| 135 | + signature_key=self._args.input_signature_key |
| 136 | + ) |
125 | 137 |
|
126 | 138 | else: |
127 | 139 |
|
@@ -231,6 +243,13 @@ def engine_build_input_fn(num_batches, model_phase): |
231 | 243 | f"Converted graph saved to " |
232 | 244 | f"`{self._args.output_saved_model_dir}`" |
233 | 245 | ) |
| 246 | + # The engine cache is cleared while saving, so we have to reload.
| 247 | + # Failing to do so would force TF-TRT to rebuild the engines.
| 248 | + del converter |
| 249 | + del graph_func |
| 250 | + graph_func = load_model_from_disk( |
| 251 | + self._args.output_saved_model_dir |
| 252 | + ) |
234 | 253 |
|
235 | 254 | if isinstance(graph_func.structured_outputs, (tuple, list)): |
236 | 255 | savedmodel_outputs = "\n - ".join([ |
|
0 commit comments