Skip to content

Commit a5805b5

Browse files
authored
fix: Add testing for explicit model load (#8276) (#8343)
1 parent 6c1e449 commit a5805b5

File tree

2 files changed

+254
-2
lines changed

2 files changed

+254
-2
lines changed

qa/L0_backend_python/model_control/model_control_test.py

Lines changed: 240 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22

3-
# Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
44
#
55
# Redistribution and use in source and binary forms, with or without
66
# modification, are permitted provided that the following conditions
@@ -26,7 +26,10 @@
2626
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828

29+
import base64
30+
import json
2931
import os
32+
import subprocess
3033
import sys
3134

3235
sys.path.append("../../common")
@@ -83,5 +86,241 @@ def test_model_reload(self):
8386
self.assertFalse(client.is_model_ready(ensemble_model_name))
8487

8588

89+
class ModelIDValidationTest(unittest.TestCase):
90+
"""
91+
Test model ID validation for user-provided model names.
92+
93+
Verifies that model names containing dangerous characters are properly rejected.
94+
Uses raw HTTP requests via curl instead of the Triton client to test server-side
95+
validation without the Triton client encoding special characters.
96+
"""
97+
98+
def setUp(self):
99+
self._shm_leak_detector = shm_util.ShmLeakDetector()
100+
self._client = httpclient.InferenceServerClient(f"{_tritonserver_ipaddr}:8000")
101+
self._triton_host = _tritonserver_ipaddr
102+
self._triton_port = 8000
103+
104+
# Check if curl is available
105+
try:
106+
subprocess.run(["curl", "--version"], capture_output=True, check=True)
107+
except (subprocess.CalledProcessError, FileNotFoundError):
108+
self.skipTest("curl command not available - required for raw HTTP testing")
109+
110+
def _send_load_model_request(self, model_name):
111+
"""Send HTTP request to load model for testing input validation using curl"""
112+
113+
# Create simple Triton Python model code
114+
python_model_code = f"""import triton_python_backend_utils as pb_utils
115+
116+
class TritonPythonModel:
117+
def execute(self, requests):
118+
print('Hello world from model {model_name}')
119+
responses = []
120+
for request in requests:
121+
# Simple identity function
122+
input_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT0")
123+
out_tensor = pb_utils.Tensor("OUTPUT0", input_tensor.as_numpy())
124+
responses.append(pb_utils.InferenceResponse([out_tensor]))
125+
return responses"""
126+
127+
# Base64 encode the Python code (as required by Triton server)
128+
python_code_b64 = base64.b64encode(python_model_code.encode("utf-8")).decode(
129+
"ascii"
130+
)
131+
132+
# Create simple config
133+
config = {
134+
"name": model_name,
135+
"backend": "python",
136+
"max_batch_size": 4,
137+
"input": [{"name": "INPUT0", "data_type": "TYPE_FP32", "dims": [-1]}],
138+
"output": [{"name": "OUTPUT0", "data_type": "TYPE_FP32", "dims": [-1]}],
139+
}
140+
141+
payload = {
142+
"parameters": {
143+
"config": json.dumps(config),
144+
"file:/1/model.py": python_code_b64,
145+
}
146+
}
147+
148+
url = f"http://{self._triton_host}:{self._triton_port}/v2/repository/models/{model_name}/load"
149+
150+
# Convert payload to JSON string
151+
payload_json = json.dumps(payload)
152+
153+
try:
154+
# Use curl to send the request
155+
curl_cmd = [
156+
"curl",
157+
"-s",
158+
"-w",
159+
"\n%{http_code}", # Write HTTP status code on separate line
160+
"-X",
161+
"POST",
162+
"-H",
163+
"Content-Type: application/json",
164+
"-d",
165+
payload_json,
166+
url,
167+
]
168+
169+
result = subprocess.run(
170+
curl_cmd, capture_output=True, text=True, timeout=10
171+
)
172+
173+
# Parse curl output - last line is status code, rest is response body
174+
output_lines = (
175+
result.stdout.strip().split("\n") if result.stdout.strip() else []
176+
)
177+
if len(output_lines) >= 2:
178+
try:
179+
status_code = int(output_lines[-1])
180+
response_text = "\n".join(output_lines[:-1])
181+
except ValueError:
182+
status_code = 0
183+
response_text = result.stdout or result.stderr or "Invalid response"
184+
elif len(output_lines) == 1 and output_lines[0].isdigit():
185+
status_code = int(output_lines[0])
186+
response_text = result.stderr or "No response body"
187+
else:
188+
status_code = 0
189+
response_text = result.stdout or result.stderr or "No response"
190+
191+
# Return an object similar to requests.Response
192+
class CurlResponse:
193+
def __init__(self, status_code, text):
194+
self.status_code = status_code
195+
self.text = text
196+
self.content = text.encode()
197+
198+
return CurlResponse(status_code, response_text)
199+
200+
except (
201+
subprocess.TimeoutExpired,
202+
subprocess.CalledProcessError,
203+
ValueError,
204+
) as e:
205+
# Return a mock response for errors
206+
class ErrorResponse:
207+
def __init__(self, error_msg):
208+
self.status_code = 0
209+
self.text = f"Error: {error_msg}"
210+
self.content = self.text.encode()
211+
212+
return ErrorResponse(str(e))
213+
214+
def test_invalid_character_model_names(self):
215+
"""Test that model names with invalid characters are properly rejected"""
216+
217+
# Based on INVALID_CHARS = ";|&$`<>()[]{}\\\"'*?~#!"
218+
invalid_model_names = [
219+
r"model;test",
220+
r"model|test",
221+
r"model&test",
222+
r"model$test",
223+
r"model`test`",
224+
r"model<test>",
225+
r"model(test)",
226+
# r"model[test]", # request fails to send unencoded
227+
r"model{test}",
228+
r"model\test",
229+
r'model"test"',
230+
r"model'test'",
231+
r"model*test",
232+
# r"model?test", # request fails to send unencoded
233+
r"model~test",
234+
# r"model#test", # request fails to send unencoded
235+
r"model!test",
236+
]
237+
238+
for invalid_name in invalid_model_names:
239+
with self.subTest(model_name=invalid_name):
240+
print(f"Testing invalid model name: {invalid_name}")
241+
242+
response = self._send_load_model_request(invalid_name)
243+
print(
244+
f"Response for '{invalid_name}': Status {response.status_code}, Text: {response.text[:200]}..."
245+
)
246+
247+
# Should not get a successful 200 response
248+
self.assertNotEqual(
249+
200,
250+
response.status_code,
251+
f"Invalid model name '{invalid_name}' should not get 200 OK response",
252+
)
253+
254+
# Special case for curly braces - they get stripped and cause load failures prior to the validation check
255+
if "{" in invalid_name or "}" in invalid_name:
256+
self.assertIn(
257+
"failed to load",
258+
response.text,
259+
f"Model with curly braces '{invalid_name}' should fail to load",
260+
)
261+
else:
262+
# Normal case - should get character validation error
263+
self.assertIn(
264+
"Invalid stub name: contains invalid characters",
265+
response.text,
266+
f"invalid response for '{invalid_name}' should contain 'Invalid stub name: contains invalid characters'",
267+
)
268+
269+
# Verify the model is not loaded/ready since it was rejected
270+
try:
271+
self.assertFalse(
272+
self._client.is_model_ready(invalid_name),
273+
f"Model '{invalid_name}' should not be ready after failed load attempt",
274+
)
275+
except Exception as e:
276+
# If checking model readiness fails, that's also acceptable since the model name is invalid
277+
print(
278+
f"Note: Could not check model readiness for '{invalid_name}': {e}"
279+
)
280+
281+
def test_valid_model_names(self):
282+
"""Test that valid model names work"""
283+
284+
valid_model_names = [
285+
"TestModel123",
286+
"model-with-hyphens",
287+
"model_with_underscores",
288+
]
289+
290+
for valid_name in valid_model_names:
291+
with self.subTest(model_name=valid_name):
292+
print(f"Testing valid model name: {valid_name}")
293+
294+
response = self._send_load_model_request(valid_name)
295+
print(
296+
f"Response for valid '{valid_name}': Status {response.status_code}, Text: {response.text[:100]}..."
297+
)
298+
299+
# Valid model names should be accepted and load successfully
300+
self.assertEqual(
301+
200,
302+
response.status_code,
303+
f"Valid model name '{valid_name}' should get 200 OK response, got {response.status_code}. Response: {response.text}",
304+
)
305+
306+
# Should not contain validation error message
307+
self.assertNotIn(
308+
"Invalid stub name: contains invalid characters",
309+
response.text,
310+
f"Valid model name '{valid_name}' should not contain validation error message",
311+
)
312+
313+
# Verify the model is actually loaded by checking if it's ready
314+
try:
315+
self.assertTrue(
316+
self._client.is_model_ready(valid_name),
317+
f"Model '{valid_name}' should be ready after successful load",
318+
)
319+
# Clean up - unload the model after testing
320+
self._client.unload_model(valid_name)
321+
except Exception as e:
322+
self.fail(f"Failed to check if model '{valid_name}' is ready: {e}")
323+
324+
86325
if __name__ == "__main__":
87326
unittest.main()

qa/L0_backend_python/model_control/test.sh

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -55,11 +55,24 @@ if [ $? -ne 0 ]; then
5555
echo -e "\n***\n*** model_control_test.py FAILED. \n***"
5656
RET=1
5757
fi
58+
59+
echo -e "\n***\n*** Running model ID validation test\n***"
60+
SUBTEST="model_id_validation"
61+
python3 -m pytest --junitxml=model_control.${SUBTEST}.report.xml model_control_test.py::ModelIDValidationTest >> ${CLIENT_LOG} 2>&1
62+
63+
if [ $? -ne 0 ]; then
64+
echo -e "\n***\n*** model_id_validation_test.py FAILED. \n***"
65+
RET=1
66+
fi
67+
5868
set -e
5969

6070
kill_server
6171

6272
if [ $RET -eq 1 ]; then
73+
echo -e "\n***\n*** Server logs:\n***"
74+
cat $SERVER_LOG
75+
echo -e "\n***\n*** Client logs:\n***"
6376
cat $CLIENT_LOG
6477
echo -e "\n***\n*** model_control_test FAILED. \n***"
6578
else

0 commit comments

Comments
 (0)