Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 240 additions & 1 deletion qa/L0_backend_python/model_control/model_control_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand All @@ -26,7 +26,10 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import base64
import json
import os
import subprocess
import sys

sys.path.append("../../common")
Expand Down Expand Up @@ -83,5 +86,241 @@ def test_model_reload(self):
self.assertFalse(client.is_model_ready(ensemble_model_name))


class ModelIDValidationTest(unittest.TestCase):
"""
Test model ID validation for user-provided model names.

Verifies that model names containing dangerous characters are properly rejected.
Uses raw HTTP requests via curl instead of the Triton client to test server-side
validation without the Triton client encoding special characters.
"""

def setUp(self):
self._shm_leak_detector = shm_util.ShmLeakDetector()
self._client = httpclient.InferenceServerClient(f"{_tritonserver_ipaddr}:8000")
self._triton_host = _tritonserver_ipaddr
self._triton_port = 8000

# Check if curl is available
try:
subprocess.run(["curl", "--version"], capture_output=True, check=True)
except (subprocess.CalledProcessError, FileNotFoundError):
self.skipTest("curl command not available - required for raw HTTP testing")

def _send_load_model_request(self, model_name):
"""Send HTTP request to load model for testing input validation using curl"""

# Create simple Triton Python model code
python_model_code = f"""import triton_python_backend_utils as pb_utils

class TritonPythonModel:
def execute(self, requests):
print('Hello world from model {model_name}')
responses = []
for request in requests:
# Simple identity function
input_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT0")
out_tensor = pb_utils.Tensor("OUTPUT0", input_tensor.as_numpy())
responses.append(pb_utils.InferenceResponse([out_tensor]))
return responses"""

# Base64 encode the Python code (as required by Triton server)
python_code_b64 = base64.b64encode(python_model_code.encode("utf-8")).decode(
"ascii"
)

# Create simple config
config = {
"name": model_name,
"backend": "python",
"max_batch_size": 4,
"input": [{"name": "INPUT0", "data_type": "TYPE_FP32", "dims": [-1]}],
"output": [{"name": "OUTPUT0", "data_type": "TYPE_FP32", "dims": [-1]}],
}

payload = {
"parameters": {
"config": json.dumps(config),
"file:/1/model.py": python_code_b64,
}
}

url = f"http://{self._triton_host}:{self._triton_port}/v2/repository/models/{model_name}/load"

# Convert payload to JSON string
payload_json = json.dumps(payload)

try:
# Use curl to send the request
curl_cmd = [
"curl",
"-s",
"-w",
"\n%{http_code}", # Write HTTP status code on separate line
"-X",
"POST",
"-H",
"Content-Type: application/json",
"-d",
payload_json,
url,
]

result = subprocess.run(
curl_cmd, capture_output=True, text=True, timeout=10
)

# Parse curl output - last line is status code, rest is response body
output_lines = (
result.stdout.strip().split("\n") if result.stdout.strip() else []
)
if len(output_lines) >= 2:
try:
status_code = int(output_lines[-1])
response_text = "\n".join(output_lines[:-1])
except ValueError:
status_code = 0
response_text = result.stdout or result.stderr or "Invalid response"
elif len(output_lines) == 1 and output_lines[0].isdigit():
status_code = int(output_lines[0])
response_text = result.stderr or "No response body"
else:
status_code = 0
response_text = result.stdout or result.stderr or "No response"

# Return an object similar to requests.Response
class CurlResponse:
def __init__(self, status_code, text):
self.status_code = status_code
self.text = text
self.content = text.encode()

return CurlResponse(status_code, response_text)

except (
subprocess.TimeoutExpired,
subprocess.CalledProcessError,
ValueError,
) as e:
# Return a mock response for errors
class ErrorResponse:
def __init__(self, error_msg):
self.status_code = 0
self.text = f"Error: {error_msg}"
self.content = self.text.encode()

return ErrorResponse(str(e))

def test_invalid_character_model_names(self):
"""Test that model names with invalid characters are properly rejected"""

# Based on INVALID_CHARS = ";|&$`<>()[]{}\\\"'*?~#!"
invalid_model_names = [
r"model;test",
r"model|test",
r"model&test",
r"model$test",
r"model`test`",
r"model<test>",
r"model(test)",
# r"model[test]", # request fails to send unencoded
r"model{test}",
r"model\test",
r'model"test"',
r"model'test'",
r"model*test",
# r"model?test", # request fails to send unencoded
r"model~test",
# r"model#test", # request fails to send unencoded
r"model!test",
]

for invalid_name in invalid_model_names:
with self.subTest(model_name=invalid_name):
print(f"Testing invalid model name: {invalid_name}")

response = self._send_load_model_request(invalid_name)
print(
f"Response for '{invalid_name}': Status {response.status_code}, Text: {response.text[:200]}..."
)

# Should not get a successful 200 response
self.assertNotEqual(
200,
response.status_code,
f"Invalid model name '{invalid_name}' should not get 200 OK response",
)

# Special case for curly braces - they get stripped and cause load failures prior to the validation check
if "{" in invalid_name or "}" in invalid_name:
self.assertIn(
"failed to load",
response.text,
f"Model with curly braces '{invalid_name}' should fail to load",
)
else:
# Normal case - should get character validation error
self.assertIn(
"Invalid stub name: contains invalid characters",
response.text,
f"invalid response for '{invalid_name}' should contain 'Invalid stub name: contains invalid characters'",
)

# Verify the model is not loaded/ready since it was rejected
try:
self.assertFalse(
self._client.is_model_ready(invalid_name),
f"Model '{invalid_name}' should not be ready after failed load attempt",
)
except Exception as e:
# If checking model readiness fails, that's also acceptable since the model name is invalid
print(
f"Note: Could not check model readiness for '{invalid_name}': {e}"
)

def test_valid_model_names(self):
"""Test that valid model names work"""

valid_model_names = [
"TestModel123",
"model-with-hyphens",
"model_with_underscores",
]

for valid_name in valid_model_names:
with self.subTest(model_name=valid_name):
print(f"Testing valid model name: {valid_name}")

response = self._send_load_model_request(valid_name)
print(
f"Response for valid '{valid_name}': Status {response.status_code}, Text: {response.text[:100]}..."
)

# Valid model names should be accepted and load successfully
self.assertEqual(
200,
response.status_code,
f"Valid model name '{valid_name}' should get 200 OK response, got {response.status_code}. Response: {response.text}",
)

# Should not contain validation error message
self.assertNotIn(
"Invalid stub name: contains invalid characters",
response.text,
f"Valid model name '{valid_name}' should not contain validation error message",
)

# Verify the model is actually loaded by checking if it's ready
try:
self.assertTrue(
self._client.is_model_ready(valid_name),
f"Model '{valid_name}' should be ready after successful load",
)
# Clean up - unload the model after testing
self._client.unload_model(valid_name)
except Exception as e:
self.fail(f"Failed to check if model '{valid_name}' is ready: {e}")


if __name__ == "__main__":
unittest.main()
15 changes: 14 additions & 1 deletion qa/L0_backend_python/model_control/test.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -55,11 +55,24 @@ if [ $? -ne 0 ]; then
echo -e "\n***\n*** model_control_test.py FAILED. \n***"
RET=1
fi

echo -e "\n***\n*** Running model ID validation test\n***"
SUBTEST="model_id_validation"
python3 -m pytest --junitxml=model_control.${SUBTEST}.report.xml model_control_test.py::ModelIDValidationTest >> ${CLIENT_LOG} 2>&1

if [ $? -ne 0 ]; then
echo -e "\n***\n*** model_id_validation_test.py FAILED. \n***"
RET=1
fi

set -e

kill_server

if [ $RET -eq 1 ]; then
echo -e "\n***\n*** Server logs:\n***"
cat $SERVER_LOG
echo -e "\n***\n*** Client logs:\n***"
cat $CLIENT_LOG
echo -e "\n***\n*** model_control_test FAILED. \n***"
else
Expand Down
Loading