Skip to content

Commit a584741

Browse files
committed
Workaround with L0_trt_reformat_free by removing shm checks
1 parent a9a2c1c commit a584741

File tree

2 files changed

+22
-44
lines changed

2 files changed

+22
-44
lines changed

src/python/library/tritonclient/grpc/_infer_input.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -102,31 +102,19 @@ def validate_data(self):
102102
if cnt != 1:
103103
return
104104

105+
# Skip due to trt reformat free tensor
105106
if "shared_memory_region" in self._input.parameters:
106-
# Using shared memory
107-
if self._input.datatype != "BYTES":
108-
expected_byte_size = num_elements(
109-
self._input.shape
110-
) * get_data_type_byte_size(self._input.datatype)
111-
data_byte_size = self._input.parameters[
112-
"shared_memory_byte_size"
113-
].int64_param
114-
if data_byte_size != expected_byte_size:
115-
raise_error(
116-
"input '{}' got unexpected byte size {}, expected {}".format(
117-
self._input.name, data_byte_size, expected_byte_size
118-
)
119-
)
120-
else:
121-
# Not using shared memory
122-
expected_num_elements = num_elements(self._input.shape)
123-
data_num_elements = num_elements(self._data_shape)
124-
if expected_num_elements != data_num_elements:
125-
raise_error(
126-
"input '{}' got unexpected elements count {}, expected {}".format(
127-
self._input.name, data_num_elements, expected_num_elements
128-
)
107+
return
108+
109+
# Not using shared memory
110+
expected_num_elements = num_elements(self._input.shape)
111+
data_num_elements = num_elements(self._data_shape)
112+
if expected_num_elements != data_num_elements:
113+
raise_error(
114+
"input '{}' got unexpected elements count {}, expected {}".format(
115+
self._input.name, data_num_elements, expected_num_elements
129116
)
117+
)
130118
return
131119

132120
def set_shape(self, shape):

src/python/library/tritonclient/http/_infer_input.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -106,29 +106,19 @@ def validate_data(self):
106106
if cnt != 1:
107107
return
108108

109+
# Skip due to trt reformat free tensor
109110
if "shared_memory_region" in self._parameters:
110-
# Using shared memory
111-
if self._datatype != "BYTES":
112-
expected_byte_size = num_elements(
113-
self._shape
114-
) * get_data_type_byte_size(self._datatype)
115-
data_byte_size = self._parameters["shared_memory_byte_size"]
116-
if data_byte_size != expected_byte_size:
117-
raise_error(
118-
"input '{}' got unexpected byte size {}, expected {}".format(
119-
self._name, data_byte_size, expected_byte_size
120-
)
121-
)
122-
else:
123-
# Not using shared memory
124-
expected_num_elements = num_elements(self._shape)
125-
data_num_elements = num_elements(self._data_shape)
126-
if expected_num_elements != data_num_elements:
127-
raise_error(
128-
"input '{}' got unexpected elements count {}, expected {}".format(
129-
self._name, data_num_elements, expected_num_elements
130-
)
111+
return
112+
113+
# Not using shared memory
114+
expected_num_elements = num_elements(self._shape)
115+
data_num_elements = num_elements(self._data_shape)
116+
if expected_num_elements != data_num_elements:
117+
raise_error(
118+
"input '{}' got unexpected elements count {}, expected {}".format(
119+
self._name, data_num_elements, expected_num_elements
131120
)
121+
)
132122
return
133123

134124
def set_shape(self, shape):

0 commit comments

Comments
 (0)