1
- // Copyright 2020-2023 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ // Copyright 2020-2024 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
//
3
3
// Redistribution and use in source and binary forms, with or without
4
4
// modification, are permitted provided that the following conditions
@@ -567,7 +567,6 @@ SetStringInputTensor(
567
567
cudaStream_t stream, const char * host_policy_name)
568
568
{
569
569
bool cuda_copy = false ;
570
- size_t element_idx = 0 ;
571
570
572
571
// For string data type, we always need to have the data on CPU so
573
572
// that we can read string length and construct the string
@@ -582,8 +581,7 @@ SetStringInputTensor(
582
581
&contiguous_buffer, stream, &cuda_copy);
583
582
if (err != nullptr ) {
584
583
RESPOND_AND_SET_NULL_IF_ERROR (response, err);
585
- FillStringTensor (
586
- tensor, tensor_offset + element_idx, request_element_cnt - element_idx);
584
+ FillStringTensor (tensor, tensor_offset, request_element_cnt);
587
585
free (contiguous_buffer);
588
586
return cuda_copy;
589
587
}
@@ -595,68 +593,21 @@ SetStringInputTensor(
595
593
}
596
594
#endif // TRITON_ENABLE_GPU
597
595
598
- // Parse content and assign to 'tensor'. Each string in 'content'
599
- // is a 4-byte length followed by the string itself with no
600
- // null-terminator.
601
- while (content_byte_size >= sizeof (uint32_t )) {
602
- if (element_idx >= request_element_cnt) {
603
- RESPOND_AND_SET_NULL_IF_ERROR (
604
- response,
605
- TRITONSERVER_ErrorNew (
606
- TRITONSERVER_ERROR_INVALID_ARG,
607
- std::string (
608
- " unexpected number of string elements " +
609
- std::to_string (element_idx + 1 ) + " for inference input '" +
610
- name + " ', expecting " + std::to_string (request_element_cnt))
611
- .c_str ()));
612
- FillStringTensor (
613
- tensor, tensor_offset + element_idx,
614
- request_element_cnt - element_idx);
615
- free (contiguous_buffer);
616
- return cuda_copy;
617
- }
618
-
619
- const uint32_t len = *(reinterpret_cast <const uint32_t *>(content));
620
- content += sizeof (uint32_t );
621
- content_byte_size -= sizeof (uint32_t );
622
-
623
- if (content_byte_size < len) {
624
- RESPOND_AND_SET_NULL_IF_ERROR (
625
- response,
626
- TRITONSERVER_ErrorNew (
627
- TRITONSERVER_ERROR_INVALID_ARG,
628
- std::string (
629
- " incomplete string data for inference input '" +
630
- std::string (name) + " ', expecting string of length " +
631
- std::to_string (len) + " but only " +
632
- std::to_string (content_byte_size) + " bytes available" )
633
- .c_str ()));
634
- FillStringTensor (
635
- tensor, tensor_offset + element_idx,
636
- request_element_cnt - element_idx);
637
- free (contiguous_buffer);
638
- return cuda_copy;
639
- }
596
+ std::vector<std::pair<const char *, const uint32_t >> str_list;
597
+ err = ValidateStringBuffer (
598
+ content, content_byte_size, request_element_cnt, name, &str_list);
599
+ // Set string values.
600
+ for (size_t element_idx = 0 ; element_idx < str_list.size (); ++element_idx) {
601
+ const auto & [addr, len] = str_list[element_idx];
602
+ TRITONTF_TensorSetString (tensor, tensor_offset + element_idx, addr, len);
603
+ }
640
604
641
- TRITONTF_TensorSetString (tensor, tensor_offset + element_idx, content, len);
642
- content += len;
643
- content_byte_size -= len;
644
- element_idx++;
645
- }
646
-
647
- if ((*response != nullptr ) && (element_idx != request_element_cnt)) {
648
- RESPOND_AND_SET_NULL_IF_ERROR (
649
- response, TRITONSERVER_ErrorNew (
650
- TRITONSERVER_ERROR_INTERNAL,
651
- std::string (
652
- " expected " + std::to_string (request_element_cnt) +
653
- " strings for inference input '" + name + " ', got " +
654
- std::to_string (element_idx))
655
- .c_str ()));
605
+ size_t element_cnt = str_list.size ();
606
+ if (err != nullptr ) {
607
+ RESPOND_AND_SET_NULL_IF_ERROR (response, err);
656
608
FillStringTensor (
657
- tensor, tensor_offset + element_idx , request_element_cnt - element_idx );
609
+ tensor, tensor_offset + element_cnt , request_element_cnt - element_cnt );
658
610
}
659
-
660
611
free (contiguous_buffer);
661
612
return cuda_copy;
662
613
}
0 commit comments