@@ -92,26 +92,38 @@ Status LoadModel(const string& model_dir, const string& signature_key,
 Status SetupInputs(tensorflow::Device* device,
                    tensorflow::DeviceContext* device_context,
                    int32_t batch_size,
+                   int32_t input_size,
                    std::vector<tensorflow::TensorInfo>& input_info,
+                   bool input_from_device,
                    std::vector<Tensor>* inputs) {
   tensorflow::AllocatorAttributes attr;
   tensorflow::Allocator* allocator = device->GetAllocator(attr);

   std::vector<Tensor> inputs_device;
-  for (const auto& info : input_info) {
+  for (auto& info : input_info) {
     // Set input batch size
-    auto shape = info.tensor_shape();
-    shape.mutable_dim(0)->set_size(batch_size);
+    auto* shape = info.mutable_tensor_shape();
+    shape->mutable_dim(0)->set_size(batch_size);
+    for (int i = 1; i < shape->dim_size(); i++) {
+      auto* dim = shape->mutable_dim(i);
+      if (dim->size() < 0) {
+        dim->set_size(input_size);
+      }
+    }

     // Allocate memory and fill host tensor
-    Tensor input_host(info.dtype(), shape);
-    Tensor input_device(allocator, info.dtype(), shape);
+    Tensor input_host(info.dtype(), *shape);
+    Tensor input_device(allocator, info.dtype(), *shape);
     std::fill_n((uint8_t*)input_host.data(), input_host.AllocatedBytes(), 1);

     // Copy from host to device
-    TF_RETURN_IF_ERROR(device_context->CopyCPUTensorToDeviceSync(
-        &input_host, device, &input_device));
-    inputs_device.push_back(input_device);
+    if (input_from_device) {
+      TF_RETURN_IF_ERROR(device_context->CopyCPUTensorToDeviceSync(
+          &input_host, device, &input_device));
+      inputs_device.push_back(input_device);
+    } else {
+      inputs_device.push_back(input_host);
+    }
   }
   *inputs = inputs_device;
   return Status::OK();
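
For reference, here is the dimension-fixing logic added above in isolation, as a minimal sketch (the helper name FixUnknownDims is illustrative, not part of the patch): it pins the batch dimension, then substitutes input_size for every remaining unknown (-1) dimension so a concrete Tensor can be allocated.

    #include "tensorflow/core/framework/tensor_shape.pb.h"

    // Pin dim 0 to the batch size, then give every remaining unknown (-1)
    // dimension a concrete size so the shape can back a real allocation.
    void FixUnknownDims(tensorflow::TensorShapeProto* shape,
                        int64_t batch_size, int64_t input_size) {
      shape->mutable_dim(0)->set_size(batch_size);
      for (int i = 1; i < shape->dim_size(); ++i) {
        auto* dim = shape->mutable_dim(i);
        if (dim->size() < 0) {
          dim->set_size(input_size);
        }
      }
    }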
@@ -167,6 +179,7 @@ int main(int argc, char* argv[]) {
   string model_path = "/path/to/model/";
   string signature_key = "serving_default";
   int32_t batch_size = 64;
+  int32_t input_size = 128;
   int32_t warmup_iters = 200;
   int32_t eval_iters = 800;
   bool input_from_device = true;
@@ -176,6 +189,7 @@ int main(int argc, char* argv[]) {
       Flag("model_path", &model_path, "graph to be executed"),
       Flag("signature_key", &signature_key, "the serving signature to use"),
       Flag("batch_size", &batch_size, "batch size to use for inference"),
+      Flag("input_size", &input_size, "size to use for dynamic (-1) input dims"),
       Flag("warmup_iters", &warmup_iters, "number of warmup iterations to run"),
       Flag("eval_iters", &eval_iters, "number of timed iterations to run"),
       Flag("input_from_device", &input_from_device, "use inputs from device, rather than host"),
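
With the new flag wired up, a model with dynamic input dimensions can be benchmarked with an invocation along these lines (the binary name is illustrative; the flag names and defaults match the Flag() definitions above):

    ./saved_model_benchmark \
        --model_path=/path/to/model/ \
        --signature_key=serving_default \
        --batch_size=64 \
        --input_size=128 \
        --warmup_iters=200 \
        --eval_iters=800 \
        --input_from_device=false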
@@ -213,8 +227,8 @@ int main(int argc, char* argv[]) {
   // Create inputs and move to device
   // TODO: Measure H2D times over repeated calls and report metrics
   std::vector<Tensor> inputs_device;
-  TFTRT_ENSURE_OK(SetupInputs(device, device_context, batch_size, input_info,
-                              &inputs_device));
+  TFTRT_ENSURE_OK(SetupInputs(device, device_context, batch_size, input_size,
+                              input_info, input_from_device, &inputs_device));

   // Configure to feed and fetch from device
   tensorflow::Session::CallableHandle handle;
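
The CallableHandle declared here is later created via Session::MakeCallable. That configuration is outside this hunk, but a sketch of what feeding directly from device memory typically involves is shown below; the session variable, tensor names, and device string are assumptions, not taken from the patch:

    tensorflow::CallableOptions opts;
    opts.add_feed("input:0");    // assumed input tensor name
    opts.add_fetch("output:0");  // assumed output tensor name
    if (input_from_device) {
      // Declare the feed as already resident on the GPU so the session
      // skips its own host-to-device copy when the callable runs.
      (*opts.mutable_feed_devices())["input:0"] =
          "/job:localhost/replica:0/task:0/device:GPU:0";
    }
    TFTRT_ENSURE_OK(session->MakeCallable(opts, &handle));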