 from attrs import define
 from tritonclient import grpc as grpcclient
 from tritonclient import http as httpclient
+from tritonclient.grpc import model_config_pb2


 class TritonProtocol(Enum):
@@ -31,13 +32,32 @@ def dict_to_attr(obj: dict[str, Any]) -> SimpleNamespace:
     return json.loads(json.dumps(obj), object_hook=lambda d: SimpleNamespace(**d))


+@define
+class TritonModelInput:
+    """
+    Most of the fields are mapped to model_config_pb2.ModelInput (https://github.com/triton-inference-server/common/blob/a2de06f4c80b2c7b15469fa4d36e5f6445382bad/protobuf/model_config.proto#L317)
+
+    Commented fields are not used.
+    """
+
+    name: str
+    dtype: str  # data_type, mapping to https://github.com/triton-inference-server/client/blob/d257c0e5c3de6e15d6ef289ff2b96cecd0a69b5f/src/python/library/tritonclient/utils/__init__.py#L163-L190
+
+    format: int = 0
+    dims: list[int] = []  # dims
+
+    # reshape: list[int] = []
+    # is_shape_tensor: bool = False
+    # allow_ragged_batch: bool = False
+    optional: bool = False
+
+
 @define
 class TritonModelSpec:
     name: str

     max_batch_size: int
-    input_name: list[str]
-    input_dtype: list[str]
+    model_input: list[TritonModelInput]

     output_name: list[str]

@@ -91,7 +111,7 @@ def get_triton_client(
     model_name: str,
     model_version: str,
     protocol: TritonProtocol,
-):
+) -> tuple[int, list[TritonModelInput], list[str]]:
     """
     (required in)
     :param triton_client:
@@ -107,23 +127,43 @@ def get_triton_client(

     args = dict(model_name=model_name, model_version=model_version)

-    model_metadata = triton_client.get_model_metadata(**args)
     model_config = triton_client.get_model_config(**args)
     if protocol is TritonProtocol.http:
-        model_metadata = dict_to_attr(model_metadata)
         model_config = dict_to_attr(model_config)
     elif protocol is TritonProtocol.grpc:
         model_config = model_config.config

-    max_batch_size, input_name_list, output_name_list, dtype_list = parse_model(model_metadata, model_config)
+    max_batch_size, input_list, output_name_list = parse_model(model_config)
+
+    return max_batch_size, input_list, output_name_list
+

-    return max_batch_size, input_name_list, output_name_list, dtype_list
+def parse_model_input(
+    model_input: model_config_pb2.ModelInput | SimpleNamespace,
+) -> TritonModelInput:
+    """
+    https://github.com/triton-inference-server/common/blob/r23.08/protobuf/model_config.proto#L317-L412
+    """
+    RAW_DTYPE = model_input.data_type
+    if isinstance(model_input.data_type, int):
+        RAW_DTYPE = model_config_pb2.DataType.Name(RAW_DTYPE)
+    RAW_DTYPE = RAW_DTYPE.removeprefix("TYPE_")
+
+    if RAW_DTYPE == "STRING":
+        RAW_DTYPE = "BYTES"  # https://github.com/triton-inference-server/client/blob/d257c0e5c3de6e15d6ef289ff2b96cecd0a69b5f/src/python/library/tritonclient/utils/__init__.py#L188-L189
+    return TritonModelInput(
+        name=model_input.name,
+        dims=model_input.dims,
+        dtype=RAW_DTYPE,
+        optional=model_input.optional,
+    )


-def parse_model(model_metadata, model_config):
+def parse_model(
+    model_config: model_config_pb2.ModelConfig | SimpleNamespace,
+) -> tuple[int, list[TritonModelInput], list[str]]:
     return (
         model_config.max_batch_size,
-        [input_metadata.name for input_metadata in model_metadata.inputs],
-        [output_metadata.name for output_metadata in model_metadata.outputs],
-        [input_metadata.datatype for input_metadata in model_metadata.inputs],
+        [parse_model_input(model_config_input) for model_config_input in model_config.input],
+        [model_config_output.name for model_config_output in model_config.output],
     )
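
A minimal sketch (not part of the diff above) of the data_type normalization that the new parse_model_input performs; it assumes only the standard protobuf enum helpers generated in model_config_pb2:

# Illustrative sketch, not part of the PR: how data_type is normalized.
from tritonclient.grpc import model_config_pb2

# gRPC configs carry data_type as an int enum; Name() gives the "TYPE_*" string.
assert model_config_pb2.DataType.Name(model_config_pb2.TYPE_FP32) == "TYPE_FP32"

# HTTP configs already carry the string form; either way the "TYPE_" prefix is
# dropped to get the dtype name tritonclient expects ("FP32", "INT64", ...).
assert "TYPE_FP32".removeprefix("TYPE_") == "FP32"

# The one special case: Triton's TYPE_STRING is exposed as "BYTES" by tritonclient.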
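And a hedged usage sketch of the reworked return value; the server URL, the model name, and the assumption that the client is the first argument of get_triton_client are placeholders, not taken from the PR:

# Hypothetical usage, not from the PR: consuming the new
# (max_batch_size, list[TritonModelInput], output names) return shape over gRPC.
from tritonclient import grpc as grpcclient

triton_client = grpcclient.InferenceServerClient(url="localhost:8001")  # placeholder URL

max_batch_size, model_inputs, output_names = get_triton_client(
    triton_client,              # assumed to be the first parameter
    model_name="my_model",      # placeholder model name
    model_version="1",
    protocol=TritonProtocol.grpc,
)

for model_input in model_inputs:
    # Each TritonModelInput carries name, dims, dtype (e.g. "FP32") and optional.
    print(model_input.name, list(model_input.dims), model_input.dtype, model_input.optional)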