Skip to content

Commit d6750e8

Browse files
authored
Fix: Update element count handling (#8182)
Ensures requests with overly large or invalid element counts are rejected with a relevant error message.
1 parent 4fc0a5e commit d6750e8

File tree

5 files changed

+177
-7
lines changed

5 files changed

+177
-7
lines changed

qa/L0_http/http_test.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,104 @@ def test_descriptive_status_code(self):
236236
)
237237
t.join()
238238

239+
def test_buffer_size_overflow(self):
240+
model = "onnx_zero_1_float32"
241+
242+
# Test for overflow within GetElementCount()
243+
payload1 = {
244+
"inputs": [
245+
{
246+
"name": "INPUT0",
247+
"shape": [
248+
2**4,
249+
2**60 + 2,
250+
], # This evaluates to 2^64 + 32 during GetElementCount()
251+
"datatype": "FP32",
252+
"data": [1.0],
253+
}
254+
]
255+
}
256+
257+
# Test for overflow with type_byte_size multiplication
258+
payload2 = {
259+
"inputs": [
260+
{
261+
"name": "INPUT0",
262+
"shape": [
263+
2**2,
264+
2**60 + 2,
265+
], # This evaluates to 2^64 + 32 during type_byte_size multiplication since FP32 is 4 bytes
266+
"datatype": "FP32",
267+
"data": [1.0],
268+
}
269+
]
270+
}
271+
272+
# Send request and expect a 400 error with specific overflow message
273+
headers = {"Content-Type": "application/json"}
274+
275+
# Test the first payload (GetElementCount overflow)
276+
r1 = requests.post(self._get_infer_url(model), json=payload1, headers=headers)
277+
278+
self.assertEqual(
279+
400,
280+
r1.status_code,
281+
"Expected error code 400 for GetElementCount overflow check; got: {}".format(
282+
r1.status_code
283+
),
284+
)
285+
286+
error_message1 = r1.content.decode()
287+
self.assertIn(
288+
"causes total element count to exceed maximum size of", error_message1
289+
)
290+
291+
# Test the second payload (type_byte_size multiplication overflow)
292+
r2 = requests.post(self._get_infer_url(model), json=payload2, headers=headers)
293+
294+
self.assertEqual(
295+
400,
296+
r2.status_code,
297+
"Expected error code 400 for type_byte_size multiplication overflow check; got: {}".format(
298+
r2.status_code
299+
),
300+
)
301+
302+
error_message2 = r2.content.decode()
303+
self.assertIn("byte size overflow for input", error_message2)
304+
305+
def test_negative_dimensions(self):
306+
model = "onnx_zero_1_float32"
307+
308+
payload = {
309+
"inputs": [
310+
{
311+
"name": "INPUT0",
312+
"shape": [2, -5], # Negative dimension should be invalid
313+
"datatype": "FP32",
314+
"data": [1.0],
315+
}
316+
]
317+
}
318+
319+
# Send request and expect a 500 error
320+
headers = {"Content-Type": "application/json"}
321+
r = requests.post(self._get_infer_url(model), json=payload, headers=headers)
322+
323+
self.assertEqual(
324+
500,
325+
r.status_code,
326+
"Expected error code 500 for negative dimension; got: {}".format(
327+
r.status_code
328+
),
329+
)
330+
331+
error_message = r.content.decode()
332+
self.assertIn(
333+
"Unable to parse 'shape': attempt to access JSON non-unsigned-integer as unsigned-integer",
334+
error_message,
335+
)
336+
239337
def test_loading_large_invalid_model(self):
240338
# Generate large base64 encoded data
241339
data_length = 1 << 31

qa/L0_http/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ fi
630630

631631
TEST_RESULT_FILE='test_results.txt'
632632
PYTHON_TEST=http_test.py
633-
EXPECTED_NUM_TESTS=11
633+
EXPECTED_NUM_TESTS=13
634634
set +e
635635
python $PYTHON_TEST >$CLIENT_LOG 2>&1
636636
if [ $? -ne 0 ]; then

src/common.cc

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -80,6 +80,31 @@ GetEnvironmentVariableOrDefault(
8080
return value ? value : default_value;
8181
}
8282

83+
std::string
84+
ShapeToString(const int64_t* dims, const size_t dims_count)
85+
{
86+
bool first = true;
87+
88+
std::string str("[");
89+
for (size_t i = 0; i < dims_count; ++i) {
90+
const int64_t dim = dims[i];
91+
if (!first) {
92+
str += ",";
93+
}
94+
str += std::to_string(dim);
95+
first = false;
96+
}
97+
98+
str += "]";
99+
return str;
100+
}
101+
102+
std::string
103+
ShapeToString(const std::vector<int64_t>& shape)
104+
{
105+
return ShapeToString(shape.data(), shape.size());
106+
}
107+
83108
int64_t
84109
GetElementCount(const std::vector<int64_t>& dims)
85110
{
@@ -88,12 +113,20 @@ GetElementCount(const std::vector<int64_t>& dims)
88113
for (auto dim : dims) {
89114
if (dim == WILDCARD_DIM) {
90115
return -1;
116+
} else if (dim < 0) { // invalid dim
117+
return -2;
118+
} else if (dim == 0) {
119+
return 0;
91120
}
92121

93122
if (first) {
94123
cnt = dim;
95124
first = false;
96125
} else {
126+
// Check for overflow before multiplication
127+
if (cnt > (INT64_MAX / dim)) {
128+
return -3;
129+
}
97130
cnt *= dim;
98131
}
99132
}

src/common.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,18 @@ std::string GetEnvironmentVariableOrDefault(
158158
/// Get the number of elements in a shape.
159159
///
160160
/// \param dims The shape.
161-
/// \return The number of elements, or -1 if the number of elements
161+
/// \return The number of elements, -1 if the number of elements
162162
/// cannot be determined because the shape contains one or more
163-
/// wildcard dimensions.
163+
/// wildcard dimensions, -2 if the shape contains an invalid dim,
164+
/// or -3 if the number is too large to represent as an int64_t.
164165
int64_t GetElementCount(const std::vector<int64_t>& dims);
165166

167+
/// Convert shape to string representation.
168+
///
169+
/// \param shape The shape as a vector.
170+
/// \return The string representation of the shape.
171+
std::string ShapeToString(const std::vector<int64_t>& shape);
172+
166173
/// Returns if 'vec' contains 'str'.
167174
///
168175
/// \param vec The vector of strings to search.

src/http_server.cc

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2755,12 +2755,29 @@ HTTPAPIServer::ParseJsonTritonIO(
27552755
} else {
27562756
const int64_t element_cnt = GetElementCount(shape_vec);
27572757

2758-
// FIXME, element count should never be 0 or negative so
2759-
// shouldn't we just return an error here?
27602758
if (element_cnt == 0) {
27612759
RETURN_IF_ERR(TRITONSERVER_InferenceRequestAppendInputData(
27622760
irequest, input_name, nullptr, 0 /* byte_size */,
27632761
TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */));
2762+
} else if (element_cnt == -2) {
2763+
// -2 indicates invalid dimension
2764+
return TRITONSERVER_ErrorNew(
2765+
TRITONSERVER_ERROR_INVALID_ARG,
2766+
std::string(
2767+
"invalid shape for input '" + std::string(input_name) +
2768+
"': shape " + ShapeToString(shape_vec) +
2769+
" contains one or more invalid dimensions")
2770+
.c_str());
2771+
} else if (element_cnt == -3) {
2772+
// -3 indicates integer overflow
2773+
return TRITONSERVER_ErrorNew(
2774+
TRITONSERVER_ERROR_INVALID_ARG,
2775+
std::string(
2776+
"invalid shape for input '" + std::string(input_name) +
2777+
"': shape " + ShapeToString(shape_vec) +
2778+
" causes total element count to exceed maximum size of " +
2779+
std::to_string(INT64_MAX))
2780+
.c_str());
27642781
} else {
27652782
// JSON... presence of "data" already validated but still
27662783
// checking here. Flow in this endpoint needs to be
@@ -2773,7 +2790,22 @@ HTTPAPIServer::ParseJsonTritonIO(
27732790
if (dtype == TRITONSERVER_TYPE_BYTES) {
27742791
RETURN_IF_ERR(JsonBytesArrayByteSize(tensor_data, &byte_size));
27752792
} else {
2776-
byte_size = element_cnt * TRITONSERVER_DataTypeByteSize(dtype);
2793+
const uint32_t type_byte_size =
2794+
TRITONSERVER_DataTypeByteSize(dtype);
2795+
if ((type_byte_size > 1) &&
2796+
(element_cnt > (INT64_MAX / type_byte_size))) {
2797+
return TRITONSERVER_ErrorNew(
2798+
TRITONSERVER_ERROR_INVALID_ARG,
2799+
std::string(
2800+
"byte size overflow for input '" +
2801+
std::string(input_name) + "': element count (" +
2802+
std::to_string(element_cnt) + ") * data type size (" +
2803+
std::to_string(type_byte_size) +
2804+
") exceeds maximum allowed size (" +
2805+
std::to_string(INT64_MAX) + ")")
2806+
.c_str());
2807+
}
2808+
byte_size = element_cnt * type_byte_size;
27772809
}
27782810

27792811
infer_req->serialized_data_.emplace_back();

0 commit comments

Comments
 (0)