@@ -102,26 +102,38 @@ Status LoadModel(const string& model_dir, const string& signature_key,
102102Status SetupInputs (tensorflow::Device* device,
103103 tensorflow::DeviceContext* device_context,
104104 int32_t batch_size,
105+ int32_t input_size,
105106 std::vector<tensorflow::TensorInfo>& input_info,
107+ bool input_from_device,
106108 std::vector<Tensor>* inputs) {
107109 tensorflow::AllocatorAttributes attr;
108110 tensorflow::Allocator* allocator = device->GetAllocator (attr);
109111
110112 std::vector<Tensor> inputs_device;
111- for (const auto & info : input_info) {
113+ for (auto & info : input_info) {
112114 // Set input batch size
113- auto shape = info.tensor_shape ();
114- shape.mutable_dim (0 )->set_size (batch_size);
115+ auto * shape = info.mutable_tensor_shape ();
116+ shape->mutable_dim (0 )->set_size (batch_size);
117+ for (size_t i = 1 ; i < shape->dim_size (); i++) {
118+ auto * dim = shape->mutable_dim (i);
119+ if (dim->size () < 0 ) {
120+ dim->set_size (input_size);
121+ }
122+ }
115123
116124 // Allocate memory and fill host tensor
117- Tensor input_host (info.dtype (), shape);
118- Tensor input_device (allocator, info.dtype (), shape);
125+ Tensor input_host (info.dtype (), * shape);
126+ Tensor input_device (allocator, info.dtype (), * shape);
119127 std::fill_n ((uint8_t *)input_host.data (), input_host.AllocatedBytes (), 1 );
120128
121129 // Copy from host to device
122- TF_RETURN_IF_ERROR (device_context->CopyCPUTensorToDeviceSync (
123- &input_host, device, &input_device));
124- inputs_device.push_back (input_device);
130+ if (input_from_device) {
131+ TF_RETURN_IF_ERROR (device_context->CopyCPUTensorToDeviceSync (
132+ &input_host, device, &input_device));
133+ inputs_device.push_back (input_device);
134+ } else {
135+ inputs_device.push_back (input_host);
136+ }
125137 }
126138 *inputs = inputs_device;
127139 return Status::OK ();
@@ -208,6 +220,7 @@ int main(int argc, char* argv[]) {
208220 string model_path = " /path/to/model/" ;
209221 string signature_key = " serving_default" ;
210222 int32_t batch_size = 64 ;
223+ int32_t input_size = 128 ;
211224 int32_t warmup_iters = 200 ;
212225 int32_t eval_iters = 800 ;
213226 bool input_from_device = true ;
@@ -217,6 +230,7 @@ int main(int argc, char* argv[]) {
217230 Flag (" model_path" , &model_path, " graph to be executed" ),
218231 Flag (" signature_key" , &signature_key, " the serving signature to use" ),
219232 Flag (" batch_size" , &batch_size, " batch size to use for inference" ),
233+ Flag (" input_size" , &input_size, " shape to use for -1 input dims" ),
220234 Flag (" warmup_iters" , &warmup_iters, " number of warmup iterations to run" ),
221235 Flag (" eval_iters" , &eval_iters, " number of timed iterations to run" ),
222236 Flag (" input_from_device" , &input_from_device, " use inputs from device, rather than host" ),
@@ -254,8 +268,8 @@ int main(int argc, char* argv[]) {
254268 // Create inputs and move to device
255269 // TODO: Measure H2D times over repeated calls and report metrics
256270 std::vector<Tensor> inputs_device;
257- TFTRT_ENSURE_OK (SetupInputs (device, device_context, batch_size, input_info,
258- &inputs_device));
271+ TFTRT_ENSURE_OK (SetupInputs (device, device_context, batch_size, input_size, input_info,
272+ input_from_device, &inputs_device));
259273
260274 // Configure to feed and fetch from device
261275 tensorflow::Session::CallableHandle handle;
0 commit comments