diff --git a/src/onnxruntime.cc b/src/onnxruntime.cc index 58c50c3..acdb81e 100644 --- a/src/onnxruntime.cc +++ b/src/onnxruntime.cc @@ -2477,58 +2477,33 @@ ModelInstanceState::SetStringInputBuffer( std::vector* responses, char* input_buffer, std::vector* string_ptrs) { + std::vector> str_list; // offset for each response size_t buffer_copy_offset = 0; for (size_t idx = 0; idx < expected_byte_sizes.size(); idx++) { const size_t expected_byte_size = expected_byte_sizes[idx]; const size_t expected_element_cnt = expected_element_cnts[idx]; - size_t element_cnt = 0; if ((*responses)[idx] != nullptr) { - size_t remaining_bytes = expected_byte_size; char* data_content = input_buffer + buffer_copy_offset; - // Continue if the remaining bytes may still contain size info - while (remaining_bytes >= sizeof(uint32_t)) { - if (element_cnt >= expected_element_cnt) { - RESPOND_AND_SET_NULL_IF_ERROR( - &((*responses)[idx]), - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("unexpected number of string elements ") + - std::to_string(element_cnt + 1) + " for inference input '" + - input_name + "', expecting " + - std::to_string(expected_element_cnt)) - .c_str())); - break; - } - - const uint32_t len = *(reinterpret_cast(data_content)); - remaining_bytes -= sizeof(uint32_t); + TRITONSERVER_Error* err = ValidateStringBuffer( + data_content, expected_byte_size, expected_element_cnt, + input_name.c_str(), &str_list); + // Set string values. + for (const auto& [addr, len] : str_list) { // Make first byte of size info 0, so that if there is string data // in front of it, the data becomes valid C string. - *data_content = 0; - data_content = data_content + sizeof(uint32_t); - if (len > remaining_bytes) { - RESPOND_AND_SET_NULL_IF_ERROR( - &((*responses)[idx]), - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("incomplete string data for inference input '") + - input_name + "', expecting string of length " + - std::to_string(len) + " but only " + - std::to_string(remaining_bytes) + " bytes available") - .c_str())); - break; - } else { - string_ptrs->push_back(data_content); - element_cnt++; - data_content = data_content + len; - remaining_bytes -= len; - } + *const_cast(addr - sizeof(uint32_t)) = 0; + string_ptrs->push_back(addr); } - } - FillStringData(string_ptrs, expected_element_cnt - element_cnt); + size_t element_cnt = str_list.size(); + if (err != nullptr) { + RESPOND_AND_SET_NULL_IF_ERROR(&((*responses)[idx]), err); + FillStringData(string_ptrs, expected_element_cnt - element_cnt); + } + str_list.clear(); + } buffer_copy_offset += expected_byte_size; } }