Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit b0b0e11

Browse files
author
DEKHTIARJonathan
committed
Adding support for NGC resnet50 v1.5
1 parent ea8f038 commit b0b0e11

File tree

4 files changed

+58
-3
lines changed

4 files changed

+58
-3
lines changed

tftrt/examples/image_classification/image_classification.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,13 @@ def __init__(self):
5555
'the model')
5656

5757
self._parser.add_argument('--preprocess_method', type=str,
58-
choices=['vgg', 'inception'], default='vgg',
58+
choices=['vgg', 'inception',
59+
'resnet50_v1_5_tf1_ngc_preprocess'
60+
],
61+
default='vgg',
5962
help='The image preprocessing method used in '
6063
'dataloading.')
6164

62-
6365
class BenchmarkRunner(BaseBenchmarkRunner):
6466

6567
ACCURACY_METRIC_NAME = "accuracy"
@@ -107,6 +109,8 @@ def get_preprocess_fn(preprocess_method, input_size):
107109
preprocess_fn = preprocessing.vgg_preprocess
108110
elif preprocess_method == 'inception':
109111
preprocess_fn = preprocessing.inception_preprocess
112+
elif preprocess_method == 'resnet50_v1_5_tf1_ngc_preprocess':
113+
preprocess_fn = preprocessing.resnet50_v1_5_tf1_ngc_preprocess
110114
else:
111115
raise ValueError(
112116
'Invalid preprocessing method {}'.format(preprocess_method)

tftrt/examples/image_classification/preprocessing.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,4 +188,43 @@ def vgg_preprocess(image, output_height, output_width):
188188
image = _central_crop([image], output_height, output_width)[0]
189189
image.set_shape([output_height, output_width, 3])
190190
image = tf.cast(image, tf.float32)
191-
return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
191+
return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
192+
193+
194+
def resnet50_v1_5_tf1_ngc_preprocess(image,
195+
height,
196+
width,
197+
central_fraction=0.875,
198+
scope=None,
199+
central_crop=False):
200+
"""Prepare one image for evaluation.
201+
If height and width are specified it would output an image with that size by
202+
applying resize_bilinear.
203+
If central_fraction is specified it would crop the central fraction of the
204+
input image.
205+
Args:
206+
image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
207+
[0, 1], otherwise it would converted to tf.float32 assuming that the range
208+
is [0, MAX], where MAX is largest positive representable number for
209+
int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
210+
height: integer
211+
width: integer
212+
central_fraction: Optional Float, fraction of the image to crop.
213+
scope: Optional scope for name_scope.
214+
central_crop: Enable central cropping of images during preprocessing for
215+
evaluation.
216+
Returns:
217+
3-D float Tensor of prepared image.
218+
"""
219+
if image.dtype != tf.float32:
220+
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
221+
# Crop the central region of the image with an area containing 87.5% of
222+
# the original image.
223+
image = tf.image.central_crop(image, central_fraction=central_fraction)
224+
if height and width:
225+
# Resize the image to the specified height and width.
226+
image = tf.expand_dims(image, 0)
227+
image = tf.image.resize(image, [height, width])
228+
image = tf.squeeze(image, [0])
229+
image = image * 255
230+
return image

tftrt/examples/image_classification/scripts/base_script.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ case ${MODEL_NAME} in
7575
NUM_CLASSES=1000
7676
;;
7777

78+
"resnet50-v1.5_tf1_ngc" )
79+
NUM_CLASSES=1000
80+
OUTPUT_TENSOR_IDX_FLAG="--output_tensor_indices=0"
81+
OUTPUT_TENSOR_NAME_FLAG="--output_tensor_names=classes"
82+
PREPROCESS_METHOD="resnet50_v1_5_tf1_ngc_preprocess"
83+
;;
84+
7885
"resnet50v2_backbone" | "resnet50v2_sparse_backbone" )
7986
INPUT_SIZE=256
8087
;;
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/bin/bash
2+
3+
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
4+
5+
bash ${SCRIPT_DIR}/base_script.sh --model_name="resnet50-v1.5_tf1_ngc" ${@}

0 commit comments

Comments
 (0)