Skip to content

Commit 5a1ed16

Browse files
committed
Refactor string input checks
1 parent 515466c commit 5a1ed16

File tree

1 file changed

+17
-60
lines changed

1 file changed

+17
-60
lines changed

src/tensorflow.cc

Lines changed: 17 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -582,8 +582,11 @@ SetStringInputTensor(
582582
&contiguous_buffer, stream, &cuda_copy);
583583
if (err != nullptr) {
584584
RESPOND_AND_SET_NULL_IF_ERROR(response, err);
585-
FillStringTensor(
586-
tensor, tensor_offset + element_idx, request_element_cnt - element_idx);
585+
if (element_idx < request_element_cnt) {
586+
FillStringTensor(
587+
tensor, tensor_offset + element_idx,
588+
request_element_cnt - element_idx);
589+
}
587590
free(contiguous_buffer);
588591
return cuda_copy;
589592
}
@@ -595,68 +598,22 @@ SetStringInputTensor(
595598
}
596599
#endif // TRITON_ENABLE_GPU
597600

598-
// Parse content and assign to 'tensor'. Each string in 'content'
599-
// is a 4-byte length followed by the string itself with no
600-
// null-terminator.
601-
while (content_byte_size >= sizeof(uint32_t)) {
602-
if (element_idx >= request_element_cnt) {
603-
RESPOND_AND_SET_NULL_IF_ERROR(
604-
response,
605-
TRITONSERVER_ErrorNew(
606-
TRITONSERVER_ERROR_INVALID_ARG,
607-
std::string(
608-
"unexpected number of string elements " +
609-
std::to_string(element_idx + 1) + " for inference input '" +
610-
name + "', expecting " + std::to_string(request_element_cnt))
611-
.c_str()));
612-
FillStringTensor(
613-
tensor, tensor_offset + element_idx,
614-
request_element_cnt - element_idx);
615-
free(contiguous_buffer);
616-
return cuda_copy;
617-
}
618-
619-
const uint32_t len = *(reinterpret_cast<const uint32_t*>(content));
620-
content += sizeof(uint32_t);
621-
content_byte_size -= sizeof(uint32_t);
622-
623-
if (content_byte_size < len) {
624-
RESPOND_AND_SET_NULL_IF_ERROR(
625-
response,
626-
TRITONSERVER_ErrorNew(
627-
TRITONSERVER_ERROR_INVALID_ARG,
628-
std::string(
629-
"incomplete string data for inference input '" +
630-
std::string(name) + "', expecting string of length " +
631-
std::to_string(len) + " but only " +
632-
std::to_string(content_byte_size) + " bytes available")
633-
.c_str()));
634-
FillStringTensor(
635-
tensor, tensor_offset + element_idx,
636-
request_element_cnt - element_idx);
637-
free(contiguous_buffer);
638-
return cuda_copy;
639-
}
640-
601+
auto callback = [](TRITONTF_Tensor* tensor, const size_t tensor_offset,
602+
const size_t element_idx, const char* content,
603+
const uint32_t len) {
641604
TRITONTF_TensorSetString(tensor, tensor_offset + element_idx, content, len);
642-
content += len;
643-
content_byte_size -= len;
644-
element_idx++;
645-
}
646-
647-
if ((*response != nullptr) && (element_idx != request_element_cnt)) {
648-
RESPOND_AND_SET_NULL_IF_ERROR(
649-
response, TRITONSERVER_ErrorNew(
650-
TRITONSERVER_ERROR_INTERNAL,
651-
std::string(
652-
"expected " + std::to_string(request_element_cnt) +
653-
" strings for inference input '" + name + "', got " +
654-
std::to_string(element_idx))
655-
.c_str()));
605+
};
606+
auto fn = std::bind(
607+
callback, tensor, tensor_offset, std::placeholders::_1,
608+
std::placeholders::_2, std::placeholders::_3);
609+
610+
err = ValidateStringBuffer(
611+
content, content_byte_size, request_element_cnt, name, &element_idx, fn);
612+
if (err != nullptr) {
613+
RESPOND_AND_SET_NULL_IF_ERROR(response, err);
656614
FillStringTensor(
657615
tensor, tensor_offset + element_idx, request_element_cnt - element_idx);
658616
}
659-
660617
free(contiguous_buffer);
661618
return cuda_copy;
662619
}

0 commit comments

Comments
 (0)