-
Notifications
You must be signed in to change notification settings - Fork 246
feat: Client-side input shape/element validation #742
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 13 commits
3178f99
7210d00
9c2941b
b4c6a17
e5e6b7e
07059a6
8b699c0
60f3f52
2a5c507
6b56c3b
223b9d8
a9a2c1c
a584741
5889b8e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| #!/usr/bin/env python3 | ||
|
|
||
| # Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # Redistribution and use in source and binary forms, with or without | ||
| # modification, are permitted provided that the following conditions | ||
|
|
@@ -30,7 +30,7 @@ | |
| from tritonclient.grpc import service_pb2 | ||
| from tritonclient.utils import * | ||
|
|
||
| from ._utils import raise_error | ||
| from ._utils import get_data_type_byte_size, num_elements, raise_error | ||
|
|
||
|
|
||
| class InferInput: | ||
|
|
@@ -54,6 +54,7 @@ def __init__(self, name, shape, datatype): | |
| self._input.ClearField("shape") | ||
| self._input.shape.extend(shape) | ||
| self._input.datatype = datatype | ||
| self._data_shape = None | ||
| self._raw_content = None | ||
|
|
||
| def name(self): | ||
|
|
@@ -86,6 +87,36 @@ def shape(self): | |
| """ | ||
| return self._input.shape | ||
|
|
||
| def validate_data(self): | ||
| """Validate input has data and input shape matches input data. | ||
|
|
||
| Returns | ||
| ------- | ||
| None | ||
| """ | ||
| # Input must set only one of the following fields: '_raw_content', | ||
| # 'shared_memory_region' in '_input.parameters' | ||
| cnt = 0 | ||
| cnt += self._raw_content != None | ||
| cnt += "shared_memory_region" in self._input.parameters | ||
| if cnt != 1: | ||
| return | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't we return an error when more that one fields are specified in the inputs? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This error was handled by the server. |
||
|
|
||
| # Skip due to trt reformat free tensor | ||
| if "shared_memory_region" in self._input.parameters: | ||
| return | ||
|
|
||
| # Not using shared memory | ||
| expected_num_elements = num_elements(self._input.shape) | ||
| data_num_elements = num_elements(self._data_shape) | ||
| if expected_num_elements != data_num_elements: | ||
| raise_error( | ||
| "input '{}' got unexpected elements count {}, expected {}".format( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you also include the respective shapes in the error message as well? something like:
I think you are trying to keep supporting the case where a user might just want to call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are shapes (8,2), (4,4) also valid? |
||
| self._input.name, data_num_elements, expected_num_elements | ||
| ) | ||
| ) | ||
| return | ||
|
|
||
| def set_shape(self, shape): | ||
| """Set the shape of input. | ||
|
|
||
|
|
@@ -171,6 +202,7 @@ def set_data_from_numpy(self, input_tensor): | |
| self._raw_content = b"" | ||
| else: | ||
| self._raw_content = input_tensor.tobytes() | ||
| self._data_shape = input_tensor.shape | ||
| return self | ||
|
|
||
| def set_shared_memory(self, region_name, byte_size, offset=0): | ||
|
|
@@ -193,6 +225,7 @@ def set_shared_memory(self, region_name, byte_size, offset=0): | |
| """ | ||
| self._input.ClearField("contents") | ||
| self._raw_content = None | ||
| self._data_shape = None | ||
|
|
||
| self._input.parameters["shared_memory_region"].string_param = region_name | ||
| self._input.parameters["shared_memory_byte_size"].int64_param = byte_size | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.