-
Notifications
You must be signed in to change notification settings - Fork 246
feat: Client-side input shape/element validation #742
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
3178f99
7210d00
9c2941b
b4c6a17
e5e6b7e
07059a6
8b699c0
60f3f52
2a5c507
6b56c3b
223b9d8
a9a2c1c
a584741
5889b8e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| // Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| // Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| // | ||
| // Redistribution and use in source and binary forms, with or without | ||
| // modification, are permitted provided that the following conditions | ||
|
|
@@ -26,6 +26,10 @@ | |
|
|
||
| #include "common.h" | ||
|
|
||
| #include <numeric> | ||
|
|
||
| #include "triton/common/model_config.h" | ||
|
|
||
| namespace triton { namespace client { | ||
|
|
||
| //============================================================================== | ||
|
|
@@ -232,6 +236,113 @@ InferInput::SetBinaryData(const bool binary_data) | |
| return Error::Success; | ||
| } | ||
|
|
||
| Error | ||
| InferInput::GetStringCount(size_t* str_cnt) const | ||
| { | ||
| int64_t str_checked = 0; | ||
| size_t remaining_str_size = 0; | ||
|
|
||
| size_t next_buf_idx = 0; | ||
| const size_t buf_cnt = bufs_.size(); | ||
|
|
||
| const uint8_t* buf = nullptr; | ||
| size_t remaining_buf_size = 0; | ||
|
|
||
| // Validate elements until all buffers have been fully processed. | ||
yinggeh marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| while (remaining_buf_size || next_buf_idx < buf_cnt) { | ||
| // Get the next buf if not currently processing one. | ||
| if (!remaining_buf_size) { | ||
| // Reset remaining buf size and pointers for next buf. | ||
| buf = bufs_[next_buf_idx]; | ||
| remaining_buf_size = buf_byte_sizes_[next_buf_idx]; | ||
| next_buf_idx++; | ||
| } | ||
|
|
||
| constexpr size_t kStringSizeIndicator = sizeof(uint32_t); | ||
| // Get the next element if not currently processing one. | ||
| if (!remaining_str_size) { | ||
| // FIXME: Assume the string element's byte size indicator is not spread | ||
| // across buf boundaries for simplicity. Also needs better log msg. | ||
| if (remaining_buf_size < kStringSizeIndicator) { | ||
| return Error("element byte size indicator exceeds the end of the buf."); | ||
| } | ||
|
|
||
| // Start the next element and reset the remaining element size. | ||
| remaining_str_size = *(reinterpret_cast<const uint32_t*>(buf)); | ||
| str_checked++; | ||
|
|
||
| // Advance pointer and remainder by the indicator size. | ||
| buf += kStringSizeIndicator; | ||
| remaining_buf_size -= kStringSizeIndicator; | ||
| } | ||
|
|
||
| // If the remaining buf fits it: consume the rest of the element, proceed | ||
| // to the next element. | ||
| if (remaining_buf_size >= remaining_str_size) { | ||
| buf += remaining_str_size; | ||
| remaining_buf_size -= remaining_str_size; | ||
| remaining_str_size = 0; | ||
| } | ||
| // Otherwise the remaining element is larger: consume the rest of the | ||
| // buf, proceed to the next buf. | ||
| else { | ||
| remaining_str_size -= remaining_buf_size; | ||
| remaining_buf_size = 0; | ||
| } | ||
| } | ||
|
|
||
| // FIXME: If more than expected, should stop earlier | ||
| // Validate the number of processed elements exactly match expectations. | ||
| *str_cnt = str_checked; | ||
| return Error::Success; | ||
| } | ||
|
|
||
| Error | ||
| InferInput::ValidateData() const | ||
|
||
| { | ||
| inference::DataType datatype = | ||
| triton::common::ProtocolStringToDataType(datatype_); | ||
| if (io_type_ == SHARED_MEMORY) { | ||
| if (datatype == inference::DataType::TYPE_STRING) { | ||
| // TODO Didn't find any shm and BYTES inputs inference example | ||
| } else { | ||
| int64_t expected_byte_size = | ||
| triton::common::GetByteSize(datatype, shape_); | ||
| if ((int64_t)byte_size_ != expected_byte_size) { | ||
| return Error( | ||
| "'" + name_ + "' got unexpected byte size " + | ||
| std::to_string(byte_size_) + ", expected " + | ||
| std::to_string(expected_byte_size)); | ||
| } | ||
| } | ||
| } else { | ||
| if (datatype == inference::DataType::TYPE_STRING) { | ||
| int64_t expected_str_cnt = triton::common::GetElementCount(shape_); | ||
| size_t str_cnt; | ||
| Error err = GetStringCount(&str_cnt); | ||
| if (!err.IsOk()) { | ||
| return err; | ||
| } | ||
| if ((int64_t)str_cnt != expected_str_cnt) { | ||
| return Error( | ||
| "'" + name_ + "' got unexpected string count " + | ||
| std::to_string(str_cnt) + ", expected " + | ||
| std::to_string(expected_str_cnt)); | ||
| } | ||
| } else { | ||
| int64_t expected_byte_size = | ||
| triton::common::GetByteSize(datatype, shape_); | ||
| if ((int64_t)byte_size_ != expected_byte_size) { | ||
| return Error( | ||
| "'" + name_ + "' got unexpected byte size " + | ||
| std::to_string(byte_size_) + ", expected " + | ||
| std::to_string(expected_byte_size)); | ||
| } | ||
| } | ||
| } | ||
| return Error::Success; | ||
| } | ||
|
|
||
| Error | ||
| InferInput::PrepareForRequest() | ||
| { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't proto-library used for protobuf<->grpc? Why is it needed for HTTP client?
edit: guessing the requirement is here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any concerns with introducing the new protobuf dependency to the HTTP client, or any alternatives? CC @GuanLuo @tanmayv25