Skip to content

Commit d9feb72

Browse files
committed
Revert "refactor: Add string input checks (#136)"
This reverts commit 0d76fbf.
1 parent 7425a42 commit d9feb72

File tree

1 file changed

+57
-11
lines changed

1 file changed

+57
-11
lines changed

src/libtorch.cc

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,6 +1914,7 @@ SetStringInputTensor(
19141914
cudaStream_t stream, const char* host_policy_name)
19151915
{
19161916
bool cuda_copy = false;
1917+
size_t element_idx = 0;
19171918

19181919
// For string data type, we always need to have the data on CPU so
19191920
// that we can read string length and construct the string
@@ -1928,7 +1929,7 @@ SetStringInputTensor(
19281929
stream, &cuda_copy);
19291930
if (err != nullptr) {
19301931
RESPOND_AND_SET_NULL_IF_ERROR(response, err);
1931-
FillStringTensor(input_list, request_element_cnt);
1932+
FillStringTensor(input_list, request_element_cnt - element_idx);
19321933
return cuda_copy;
19331934
}
19341935

@@ -1939,19 +1940,64 @@ SetStringInputTensor(
19391940
}
19401941
#endif // TRITON_ENABLE_GPU
19411942

1942-
std::vector<std::pair<const char*, const uint32_t>> str_list;
1943-
err = ValidateStringBuffer(
1944-
content, content_byte_size, request_element_cnt, name, &str_list);
1945-
// Set string values.
1946-
for (const auto& [addr, len] : str_list) {
1947-
input_list->push_back(std::string(addr, len));
1943+
// Parse content and assign to 'tensor'. Each string in 'content'
1944+
// is a 4-byte length followed by the string itself with no
1945+
// null-terminator.
1946+
while (content_byte_size >= sizeof(uint32_t)) {
1947+
if (element_idx >= request_element_cnt) {
1948+
RESPOND_AND_SET_NULL_IF_ERROR(
1949+
response,
1950+
TRITONSERVER_ErrorNew(
1951+
TRITONSERVER_ERROR_INVALID_ARG,
1952+
std::string(
1953+
"unexpected number of string elements " +
1954+
std::to_string(element_idx + 1) + " for inference input '" +
1955+
name + "', expecting " + std::to_string(request_element_cnt))
1956+
.c_str()));
1957+
return cuda_copy;
1958+
}
1959+
1960+
const uint32_t len = *(reinterpret_cast<const uint32_t*>(content));
1961+
content += sizeof(uint32_t);
1962+
content_byte_size -= sizeof(uint32_t);
1963+
1964+
if (content_byte_size < len) {
1965+
RESPOND_AND_SET_NULL_IF_ERROR(
1966+
response,
1967+
TRITONSERVER_ErrorNew(
1968+
TRITONSERVER_ERROR_INVALID_ARG,
1969+
std::string(
1970+
"incomplete string data for inference input '" +
1971+
std::string(name) + "', expecting string of length " +
1972+
std::to_string(len) + " but only " +
1973+
std::to_string(content_byte_size) + " bytes available")
1974+
.c_str()));
1975+
FillStringTensor(input_list, request_element_cnt - element_idx);
1976+
return cuda_copy;
1977+
}
1978+
1979+
// Set string value
1980+
input_list->push_back(std::string(content, len));
1981+
1982+
content += len;
1983+
content_byte_size -= len;
1984+
element_idx++;
19481985
}
19491986

1950-
size_t element_cnt = str_list.size();
1951-
if (err != nullptr) {
1952-
RESPOND_AND_SET_NULL_IF_ERROR(response, err);
1953-
FillStringTensor(input_list, request_element_cnt - element_cnt);
1987+
if ((*response != nullptr) && (element_idx != request_element_cnt)) {
1988+
RESPOND_AND_SET_NULL_IF_ERROR(
1989+
response, TRITONSERVER_ErrorNew(
1990+
TRITONSERVER_ERROR_INTERNAL,
1991+
std::string(
1992+
"expected " + std::to_string(request_element_cnt) +
1993+
" strings for inference input '" + name + "', got " +
1994+
std::to_string(element_idx))
1995+
.c_str()));
1996+
if (element_idx < request_element_cnt) {
1997+
FillStringTensor(input_list, request_element_cnt - element_idx);
1998+
}
19541999
}
2000+
19552001
return cuda_copy;
19562002
}
19572003

0 commit comments

Comments
 (0)