@@ -30,6 +30,7 @@

 from official.resnet import resnet_model
 from official.utils.arg_parsers import parsers
+from official.utils.export import export
 from official.utils.logging import hooks_helper
 from official.utils.logging import logger

@@ -219,7 +220,13 @@ def resnet_model_fn(features, labels, mode, model_class,
   }

   if mode == tf.estimator.ModeKeys.PREDICT:
-    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+    # Return the predictions and the specification for serving a SavedModel
+    return tf.estimator.EstimatorSpec(
+        mode=mode,
+        predictions=predictions,
+        export_outputs={
+            'predict': tf.estimator.export.PredictOutput(predictions)
+        })

   # Calculate loss, which includes softmax cross entropy and L2 regularization.
   cross_entropy = tf.losses.softmax_cross_entropy(
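The newly imported official.utils.export module is not part of this diff; resnet_main (further down) calls export.build_tensor_serving_input_receiver_fn to build the input side of the serving signature. As a rough sketch only, not the actual helper, such a function can wrap a fixed-shape placeholder in a TensorServingInputReceiver so the bare image tensor (rather than a feature dict) reaches the model_fn; the placeholder name, dtype default, and batch_size default below are assumptions.

    # Hedged sketch of a serving-input helper; the real
    # official.utils.export.export module may differ in names and defaults.
    import tensorflow as tf

    def build_tensor_serving_input_receiver_fn(shape, dtype=tf.float32,
                                               batch_size=1):
      """Returns a serving_input_receiver_fn for one fixed-shape input tensor."""
      def serving_input_receiver_fn():
        # Placeholder that serving requests are fed through; name is illustrative.
        features = tf.placeholder(
            dtype=dtype, shape=[batch_size] + shape, name='input_tensor')
        # TensorServingInputReceiver hands the single tensor (not a dict) to
        # the model_fn, matching how resnet_model_fn consumes `features`.
        return tf.estimator.export.TensorServingInputReceiver(
            features=features, receiver_tensors=features)
      return serving_input_receiver_fn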
@@ -310,8 +317,20 @@ def validate_batch_size_for_multi_gpu(batch_size):
     raise ValueError(err)


-def resnet_main(flags, model_function, input_function):
-  """Shared main loop for ResNet Models."""
+def resnet_main(flags, model_function, input_function, shape=None):
+  """Shared main loop for ResNet Models.
+
+  Args:
+    flags: FLAGS object that contains the params for running. See
+      ResnetArgParser for created flags.
+    model_function: the function that instantiates the Model and builds the
+      ops for train/eval. This will be passed directly into the estimator.
+    input_function: the function that processes the dataset and returns a
+      dataset that the estimator can train on. This will be wrapped with
+      all the relevant flags for running and passed to estimator.
+    shape: list of ints representing the shape of the images used for training.
+      This is only used if flags.export_dir is passed.
+  """

   # Using the Winograd non-fused algorithms provides a small performance boost.
   os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
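A dataset-specific runner is expected to pass its image shape through the new argument; a hypothetical invocation, with placeholder names that are not taken from this PR:

    # Hypothetical wiring in a dataset-specific main; my_model_fn, my_input_fn,
    # the size choices, and the 32x32x3 shape are illustrative only.
    def main(argv):
      parser = ResnetArgParser(resnet_size_choices=[18, 34, 50, 101, 152])
      flags = parser.parse_args(args=argv[1:])
      # shape is only consulted when --export_dir is set (see docstring above).
      resnet_main(flags, my_model_fn, my_input_fn, shape=[32, 32, 3])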
@@ -389,16 +408,34 @@ def input_fn_eval():
    if benchmark_logger:
      benchmark_logger.log_estimator_evaluation_result(eval_results)

+  if flags.export_dir is not None:
+    warn_on_multi_gpu_export(flags.multi_gpu)
+
+    # Exports a saved model for the given classifier.
+    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
+        shape, batch_size=flags.batch_size)
+    classifier.export_savedmodel(flags.export_dir, input_receiver_fn)
+
+
+def warn_on_multi_gpu_export(multi_gpu=False):
+  """For the time being, multi-GPU mode does not play nicely with exporting."""
+  if multi_gpu:
+    tf.logging.warning(
+        'You are exporting a SavedModel while in multi-GPU mode. Note that '
+        'the resulting SavedModel will require the same GPUs be available. '
+        'If you wish to serve the SavedModel from a different device, '
+        'try exporting the SavedModel with multi-GPU mode turned off.')
+


 class ResnetArgParser(argparse.ArgumentParser):
-  """Arguments for configuring and running a Resnet Model.
-  """
+  """Arguments for configuring and running a Resnet Model."""

   def __init__(self, resnet_size_choices=None):
     super(ResnetArgParser, self).__init__(parents=[
         parsers.BaseParser(),
         parsers.PerformanceParser(),
         parsers.ImageModelParser(),
+        parsers.ExportParser(),
         parsers.BenchmarkParser(),
     ])

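Once --export_dir is set and the run finishes, export_savedmodel writes a timestamped SavedModel directory whose 'predict' signature (named by the export_outputs key added above) can be loaded for inference. A hedged sketch using the TF 1.x SavedModel loader; the export path, image shape, and batch size are placeholders, and the fed batch must match the batch_size and shape baked into the serving placeholder at export time.

    # Hedged usage sketch, not part of this PR. Paths and shapes are illustrative.
    import numpy as np
    import tensorflow as tf

    export_dir = '/tmp/resnet_export/1520000000'  # timestamped subdir from export_savedmodel

    with tf.Session(graph=tf.Graph()) as sess:
      # Load the graph and variables tagged for serving.
      meta_graph = tf.saved_model.loader.load(
          sess, [tf.saved_model.tag_constants.SERVING], export_dir)

      # 'predict' is the export_outputs key defined in resnet_model_fn above.
      signature = meta_graph.signature_def['predict']
      input_name = list(signature.inputs.values())[0].name
      fetches = {key: info.name for key, info in signature.outputs.items()}

      # Batch dimensions here assume an export with batch_size=1 and
      # shape=[32, 32, 3]; adjust to whatever was used at export time.
      images = np.zeros((1, 32, 32, 3), dtype=np.float32)
      results = sess.run(fetches, feed_dict={input_name: images})
      print(results)  # keys mirror the predictions dict in resnet_model_fn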