Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 35ee008

Browse files
committed
Fix negative input shapes for transformer and object detection models
1 parent 0e85879 commit 35ee008

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

tftrt/benchmarking-cpp/main.cc

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,26 +92,38 @@ Status LoadModel(const string& model_dir, const string& signature_key,
9292
Status SetupInputs(tensorflow::Device* device,
9393
tensorflow::DeviceContext* device_context,
9494
int32_t batch_size,
95+
int32_t input_size,
9596
std::vector<tensorflow::TensorInfo>& input_info,
97+
bool input_from_device,
9698
std::vector<Tensor>* inputs) {
9799
tensorflow::AllocatorAttributes attr;
98100
tensorflow::Allocator* allocator = device->GetAllocator(attr);
99101

100102
std::vector<Tensor> inputs_device;
101-
for (const auto& info : input_info) {
103+
for (auto& info : input_info) {
102104
// Set input batch size
103-
auto shape = info.tensor_shape();
104-
shape.mutable_dim(0)->set_size(batch_size);
105+
auto* shape = info.mutable_tensor_shape();
106+
shape->mutable_dim(0)->set_size(batch_size);
107+
for (size_t i = 1; i < shape->dim_size(); i++) {
108+
auto* dim = shape->mutable_dim(i);
109+
if (dim->size() < 0) {
110+
dim->set_size(input_size);
111+
}
112+
}
105113

106114
// Allocate memory and fill host tensor
107-
Tensor input_host(info.dtype(), shape);
108-
Tensor input_device(allocator, info.dtype(), shape);
115+
Tensor input_host(info.dtype(), *shape);
116+
Tensor input_device(allocator, info.dtype(), *shape);
109117
std::fill_n((uint8_t*)input_host.data(), input_host.AllocatedBytes(), 1);
110118

111119
// Copy from host to device
112-
TF_RETURN_IF_ERROR(device_context->CopyCPUTensorToDeviceSync(
113-
&input_host, device, &input_device));
114-
inputs_device.push_back(input_device);
120+
if (input_from_device) {
121+
TF_RETURN_IF_ERROR(device_context->CopyCPUTensorToDeviceSync(
122+
&input_host, device, &input_device));
123+
inputs_device.push_back(input_device);
124+
} else {
125+
inputs_device.push_back(input_host);
126+
}
115127
}
116128
*inputs = inputs_device;
117129
return Status::OK();
@@ -167,6 +179,7 @@ int main(int argc, char* argv[]) {
167179
string model_path = "/path/to/model/";
168180
string signature_key = "serving_default";
169181
int32_t batch_size = 64;
182+
int32_t input_size = 128;
170183
int32_t warmup_iters = 200;
171184
int32_t eval_iters = 800;
172185
bool input_from_device = true;
@@ -176,6 +189,7 @@ int main(int argc, char* argv[]) {
176189
Flag("model_path", &model_path, "graph to be executed"),
177190
Flag("signature_key", &signature_key, "the serving signature to use"),
178191
Flag("batch_size", &batch_size, "batch size to use for inference"),
192+
Flag("input_size", &input_size, "shape to use for -1 input dims"),
179193
Flag("warmup_iters", &warmup_iters, "number of warmup iterations to run"),
180194
Flag("eval_iters", &eval_iters, "number of timed iterations to run"),
181195
Flag("input_from_device", &input_from_device, "use inputs from device, rather than host"),
@@ -213,8 +227,8 @@ int main(int argc, char* argv[]) {
213227
// Create inputs and move to device
214228
// TODO: Measure H2D times over repeated calls and report metrics
215229
std::vector<Tensor> inputs_device;
216-
TFTRT_ENSURE_OK(SetupInputs(device, device_context, batch_size, input_info,
217-
&inputs_device));
230+
TFTRT_ENSURE_OK(SetupInputs(device, device_context, batch_size, input_size, input_info,
231+
input_from_device, &inputs_device));
218232

219233
// Configure to feed and fetch from device
220234
tensorflow::Session::CallableHandle handle;

0 commit comments

Comments (0)