Skip to content

Commit 9f029a5

Browse files
authored
Check if other models on server (#756)
* Warn if non-profiled model is loaded on remote server * HTTP and GRPC have different API calls
1 parent d708869 commit 9f029a5

File tree

3 files changed

+27
-0
lines changed

3 files changed

+27
-0
lines changed

model_analyzer/analyzer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def profile(
116116
self._create_metrics_manager(client, gpus)
117117
self._create_model_manager(client, gpus)
118118

119+
if self._config.triton_launch_mode == "remote":
120+
self._warn_if_other_models_loaded_on_remote_server(client)
121+
119122
if self._config.model_repository:
120123
self._get_server_only_metrics(client, gpus)
121124
self._profile_models()
@@ -395,3 +398,15 @@ def _create_report_config(self, args: list) -> ConfigCommandReport:
395398
cli.add_subcommand(cmd="report", help="", config=config)
396399
cli.parse(args)
397400
return config
401+
402+
def _warn_if_other_models_loaded_on_remote_server(self, client):
403+
repository_index = client.get_model_repository_index()
404+
profile_model_names = [pm.model_name() for pm in self._config.profile_models]
405+
406+
for model in repository_index:
407+
if model["name"] not in profile_model_names:
408+
model_name = model["name"]
409+
logger.warning(
410+
f"A model not being profiled ({model_name}) is loaded on the remote Tritonserver. "
411+
"This could impact the profile results."
412+
)

model_analyzer/triton/client/grpc_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,9 @@ def get_model_config(self, model_name, num_retries):
7878
self.wait_for_model_ready(model_name, num_retries)
7979
model_config_dict = self._client.get_model_config(model_name, as_json=True)
8080
return model_config_dict["config"]
81+
82+
def get_model_repository_index(self):
83+
"""
84+
Returns the JSON dict holding the model repository index.
85+
"""
86+
return self._client.get_model_repository_index(as_json=True)["models"]

model_analyzer/triton/client/http_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,9 @@ def __init__(self, server_url, ssl_options={}):
9393
ssl_context_factory=ssl_context_factory,
9494
insecure=insecure,
9595
)
96+
97+
def get_model_repository_index(self):
98+
"""
99+
Returns the JSON dict holding the model repository index.
100+
"""
101+
return self._client.get_model_repository_index()

0 commit comments

Comments
 (0)