@@ -2755,12 +2755,29 @@ HTTPAPIServer::ParseJsonTritonIO(
27552755 } else {
27562756 const int64_t element_cnt = GetElementCount (shape_vec);
27572757
2758- // FIXME, element count should never be 0 or negative so
2759- // shouldn't we just return an error here?
27602758 if (element_cnt == 0 ) {
27612759 RETURN_IF_ERR (TRITONSERVER_InferenceRequestAppendInputData (
27622760 irequest, input_name, nullptr , 0 /* byte_size */ ,
27632761 TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */ ));
2762+ } else if (element_cnt == -2 ) {
2763+ // -2 indicates invalid dimension
2764+ return TRITONSERVER_ErrorNew (
2765+ TRITONSERVER_ERROR_INVALID_ARG,
2766+ std::string (
2767+ " invalid shape for input '" + std::string (input_name) +
2768+ " ': shape " + ShapeToString (shape_vec) +
2769+ " contains one or more invalid dimensions" )
2770+ .c_str ());
2771+ } else if (element_cnt == -3 ) {
2772+ // -3 indicates integer overflow
2773+ return TRITONSERVER_ErrorNew (
2774+ TRITONSERVER_ERROR_INVALID_ARG,
2775+ std::string (
2776+ " invalid shape for input '" + std::string (input_name) +
2777+ " ': shape " + ShapeToString (shape_vec) +
2778+ " causes total element count to exceed maximum size of " +
2779+ std::to_string (INT64_MAX))
2780+ .c_str ());
27642781 } else {
27652782 // JSON... presence of "data" already validated but still
27662783 // checking here. Flow in this endpoint needs to be
@@ -2773,7 +2790,22 @@ HTTPAPIServer::ParseJsonTritonIO(
27732790 if (dtype == TRITONSERVER_TYPE_BYTES) {
27742791 RETURN_IF_ERR (JsonBytesArrayByteSize (tensor_data, &byte_size));
27752792 } else {
2776- byte_size = element_cnt * TRITONSERVER_DataTypeByteSize (dtype);
2793+ const uint32_t type_byte_size =
2794+ TRITONSERVER_DataTypeByteSize (dtype);
2795+ if ((type_byte_size > 1 ) &&
2796+ (element_cnt > (INT64_MAX / type_byte_size))) {
2797+ return TRITONSERVER_ErrorNew (
2798+ TRITONSERVER_ERROR_INVALID_ARG,
2799+ std::string (
2800+ " byte size overflow for input '" +
2801+ std::string (input_name) + " ': element count (" +
2802+ std::to_string (element_cnt) + " ) * data type size (" +
2803+ std::to_string (type_byte_size) +
2804+ " ) exceeds maximum allowed size (" +
2805+ std::to_string (INT64_MAX) + " )" )
2806+ .c_str ());
2807+ }
2808+ byte_size = element_cnt * type_byte_size;
27772809 }
27782810
27792811 infer_req->serialized_data_ .emplace_back ();
0 commit comments