Skip to content

Commit 80296d0

Browse files
authored
refactor: Refactor string input checks (#104)
Refactor string input tensor checks
1 parent 515466c commit 80296d0

File tree

1 file changed

+14
-63
lines changed

1 file changed

+14
-63
lines changed

src/tensorflow.cc

Lines changed: 14 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -567,7 +567,6 @@ SetStringInputTensor(
567567
cudaStream_t stream, const char* host_policy_name)
568568
{
569569
bool cuda_copy = false;
570-
size_t element_idx = 0;
571570

572571
// For string data type, we always need to have the data on CPU so
573572
// that we can read string length and construct the string
@@ -582,8 +581,7 @@ SetStringInputTensor(
582581
&contiguous_buffer, stream, &cuda_copy);
583582
if (err != nullptr) {
584583
RESPOND_AND_SET_NULL_IF_ERROR(response, err);
585-
FillStringTensor(
586-
tensor, tensor_offset + element_idx, request_element_cnt - element_idx);
584+
FillStringTensor(tensor, tensor_offset, request_element_cnt);
587585
free(contiguous_buffer);
588586
return cuda_copy;
589587
}
@@ -595,68 +593,21 @@ SetStringInputTensor(
595593
}
596594
#endif // TRITON_ENABLE_GPU
597595

598-
// Parse content and assign to 'tensor'. Each string in 'content'
599-
// is a 4-byte length followed by the string itself with no
600-
// null-terminator.
601-
while (content_byte_size >= sizeof(uint32_t)) {
602-
if (element_idx >= request_element_cnt) {
603-
RESPOND_AND_SET_NULL_IF_ERROR(
604-
response,
605-
TRITONSERVER_ErrorNew(
606-
TRITONSERVER_ERROR_INVALID_ARG,
607-
std::string(
608-
"unexpected number of string elements " +
609-
std::to_string(element_idx + 1) + " for inference input '" +
610-
name + "', expecting " + std::to_string(request_element_cnt))
611-
.c_str()));
612-
FillStringTensor(
613-
tensor, tensor_offset + element_idx,
614-
request_element_cnt - element_idx);
615-
free(contiguous_buffer);
616-
return cuda_copy;
617-
}
618-
619-
const uint32_t len = *(reinterpret_cast<const uint32_t*>(content));
620-
content += sizeof(uint32_t);
621-
content_byte_size -= sizeof(uint32_t);
622-
623-
if (content_byte_size < len) {
624-
RESPOND_AND_SET_NULL_IF_ERROR(
625-
response,
626-
TRITONSERVER_ErrorNew(
627-
TRITONSERVER_ERROR_INVALID_ARG,
628-
std::string(
629-
"incomplete string data for inference input '" +
630-
std::string(name) + "', expecting string of length " +
631-
std::to_string(len) + " but only " +
632-
std::to_string(content_byte_size) + " bytes available")
633-
.c_str()));
634-
FillStringTensor(
635-
tensor, tensor_offset + element_idx,
636-
request_element_cnt - element_idx);
637-
free(contiguous_buffer);
638-
return cuda_copy;
639-
}
596+
std::vector<std::pair<const char*, const uint32_t>> str_list;
597+
err = ValidateStringBuffer(
598+
content, content_byte_size, request_element_cnt, name, &str_list);
599+
// Set string values.
600+
for (size_t element_idx = 0; element_idx < str_list.size(); ++element_idx) {
601+
const auto& [addr, len] = str_list[element_idx];
602+
TRITONTF_TensorSetString(tensor, tensor_offset + element_idx, addr, len);
603+
}
640604

641-
TRITONTF_TensorSetString(tensor, tensor_offset + element_idx, content, len);
642-
content += len;
643-
content_byte_size -= len;
644-
element_idx++;
645-
}
646-
647-
if ((*response != nullptr) && (element_idx != request_element_cnt)) {
648-
RESPOND_AND_SET_NULL_IF_ERROR(
649-
response, TRITONSERVER_ErrorNew(
650-
TRITONSERVER_ERROR_INTERNAL,
651-
std::string(
652-
"expected " + std::to_string(request_element_cnt) +
653-
" strings for inference input '" + name + "', got " +
654-
std::to_string(element_idx))
655-
.c_str()));
605+
size_t element_cnt = str_list.size();
606+
if (err != nullptr) {
607+
RESPOND_AND_SET_NULL_IF_ERROR(response, err);
656608
FillStringTensor(
657-
tensor, tensor_offset + element_idx, request_element_cnt - element_idx);
609+
tensor, tensor_offset + element_cnt, request_element_cnt - element_cnt);
658610
}
659-
660611
free(contiguous_buffer);
661612
return cuda_copy;
662613
}

0 commit comments

Comments
 (0)