Skip to content

Commit 0d76fbf

Browse files
authored
refactor: Add string input checks (#136)
Add string input tensor checks
1 parent c852a5e commit 0d76fbf

File tree

1 file changed

+11
-57
lines changed

1 file changed

+11
-57
lines changed

src/libtorch.cc

Lines changed: 11 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -1911,7 +1911,6 @@ SetStringInputTensor(
19111911
cudaStream_t stream, const char* host_policy_name)
19121912
{
19131913
bool cuda_copy = false;
1914-
size_t element_idx = 0;
19151914

19161915
// For string data type, we always need to have the data on CPU so
19171916
// that we can read string length and construct the string
@@ -1926,7 +1925,7 @@ SetStringInputTensor(
19261925
stream, &cuda_copy);
19271926
if (err != nullptr) {
19281927
RESPOND_AND_SET_NULL_IF_ERROR(response, err);
1929-
FillStringTensor(input_list, request_element_cnt - element_idx);
1928+
FillStringTensor(input_list, request_element_cnt);
19301929
return cuda_copy;
19311930
}
19321931

@@ -1937,64 +1936,19 @@ SetStringInputTensor(
19371936
}
19381937
#endif // TRITON_ENABLE_GPU
19391938

1940-
// Parse content and assign to 'tensor'. Each string in 'content'
1941-
// is a 4-byte length followed by the string itself with no
1942-
// null-terminator.
1943-
while (content_byte_size >= sizeof(uint32_t)) {
1944-
if (element_idx >= request_element_cnt) {
1945-
RESPOND_AND_SET_NULL_IF_ERROR(
1946-
response,
1947-
TRITONSERVER_ErrorNew(
1948-
TRITONSERVER_ERROR_INVALID_ARG,
1949-
std::string(
1950-
"unexpected number of string elements " +
1951-
std::to_string(element_idx + 1) + " for inference input '" +
1952-
name + "', expecting " + std::to_string(request_element_cnt))
1953-
.c_str()));
1954-
return cuda_copy;
1955-
}
1956-
1957-
const uint32_t len = *(reinterpret_cast<const uint32_t*>(content));
1958-
content += sizeof(uint32_t);
1959-
content_byte_size -= sizeof(uint32_t);
1960-
1961-
if (content_byte_size < len) {
1962-
RESPOND_AND_SET_NULL_IF_ERROR(
1963-
response,
1964-
TRITONSERVER_ErrorNew(
1965-
TRITONSERVER_ERROR_INVALID_ARG,
1966-
std::string(
1967-
"incomplete string data for inference input '" +
1968-
std::string(name) + "', expecting string of length " +
1969-
std::to_string(len) + " but only " +
1970-
std::to_string(content_byte_size) + " bytes available")
1971-
.c_str()));
1972-
FillStringTensor(input_list, request_element_cnt - element_idx);
1973-
return cuda_copy;
1974-
}
1975-
1976-
// Set string value
1977-
input_list->push_back(std::string(content, len));
1978-
1979-
content += len;
1980-
content_byte_size -= len;
1981-
element_idx++;
1939+
std::vector<std::pair<const char*, const uint32_t>> str_list;
1940+
err = ValidateStringBuffer(
1941+
content, content_byte_size, request_element_cnt, name, &str_list);
1942+
// Set string values.
1943+
for (const auto& [addr, len] : str_list) {
1944+
input_list->push_back(std::string(addr, len));
19821945
}
19831946

1984-
if ((*response != nullptr) && (element_idx != request_element_cnt)) {
1985-
RESPOND_AND_SET_NULL_IF_ERROR(
1986-
response, TRITONSERVER_ErrorNew(
1987-
TRITONSERVER_ERROR_INTERNAL,
1988-
std::string(
1989-
"expected " + std::to_string(request_element_cnt) +
1990-
" strings for inference input '" + name + "', got " +
1991-
std::to_string(element_idx))
1992-
.c_str()));
1993-
if (element_idx < request_element_cnt) {
1994-
FillStringTensor(input_list, request_element_cnt - element_idx);
1995-
}
1947+
size_t element_cnt = str_list.size();
1948+
if (err != nullptr) {
1949+
RESPOND_AND_SET_NULL_IF_ERROR(response, err);
1950+
FillStringTensor(input_list, request_element_cnt - element_cnt);
19961951
}
1997-
19981952
return cuda_copy;
19991953
}
20001954

0 commit comments

Comments
 (0)