diff --git a/qa/L0_backend_python/model_control/model_control_test.py b/qa/L0_backend_python/model_control/model_control_test.py index 9ccb73df4f..a0ad197bb8 100755 --- a/qa/L0_backend_python/model_control/model_control_test.py +++ b/qa/L0_backend_python/model_control/model_control_test.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -26,7 +26,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import base64 +import json import os +import subprocess import sys sys.path.append("../../common") @@ -83,5 +86,241 @@ def test_model_reload(self): self.assertFalse(client.is_model_ready(ensemble_model_name)) +class ModelIDValidationTest(unittest.TestCase): + """ + Test model ID validation for user-provided model names. + + Verifies that model names containing dangerous characters are properly rejected. + Uses raw HTTP requests via curl instead of the Triton client to test server-side + validation without the Triton client encoding special characters. 
+ """ + + def setUp(self): + self._shm_leak_detector = shm_util.ShmLeakDetector() + self._client = httpclient.InferenceServerClient(f"{_tritonserver_ipaddr}:8000") + self._triton_host = _tritonserver_ipaddr + self._triton_port = 8000 + + # Check if curl is available + try: + subprocess.run(["curl", "--version"], capture_output=True, check=True) + except (subprocess.CalledProcessError, FileNotFoundError): + self.skipTest("curl command not available - required for raw HTTP testing") + + def _send_load_model_request(self, model_name): + """Send HTTP request to load model for testing input validation using curl""" + + # Create simple Triton Python model code + python_model_code = f"""import triton_python_backend_utils as pb_utils + +class TritonPythonModel: + def execute(self, requests): + print('Hello world from model {model_name}') + responses = [] + for request in requests: + # Simple identity function + input_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT0") + out_tensor = pb_utils.Tensor("OUTPUT0", input_tensor.as_numpy()) + responses.append(pb_utils.InferenceResponse([out_tensor])) + return responses""" + + # Base64 encode the Python code (as required by Triton server) + python_code_b64 = base64.b64encode(python_model_code.encode("utf-8")).decode( + "ascii" + ) + + # Create simple config + config = { + "name": model_name, + "backend": "python", + "max_batch_size": 4, + "input": [{"name": "INPUT0", "data_type": "TYPE_FP32", "dims": [-1]}], + "output": [{"name": "OUTPUT0", "data_type": "TYPE_FP32", "dims": [-1]}], + } + + payload = { + "parameters": { + "config": json.dumps(config), + "file:/1/model.py": python_code_b64, + } + } + + url = f"http://{self._triton_host}:{self._triton_port}/v2/repository/models/{model_name}/load" + + # Convert payload to JSON string + payload_json = json.dumps(payload) + + try: + # Use curl to send the request + curl_cmd = [ + "curl", + "-s", + "-w", + "\n%{http_code}", # Write HTTP status code on separate line + "-X", + 
"POST", + "-H", + "Content-Type: application/json", + "-d", + payload_json, + url, + ] + + result = subprocess.run( + curl_cmd, capture_output=True, text=True, timeout=10 + ) + + # Parse curl output - last line is status code, rest is response body + output_lines = ( + result.stdout.strip().split("\n") if result.stdout.strip() else [] + ) + if len(output_lines) >= 2: + try: + status_code = int(output_lines[-1]) + response_text = "\n".join(output_lines[:-1]) + except ValueError: + status_code = 0 + response_text = result.stdout or result.stderr or "Invalid response" + elif len(output_lines) == 1 and output_lines[0].isdigit(): + status_code = int(output_lines[0]) + response_text = result.stderr or "No response body" + else: + status_code = 0 + response_text = result.stdout or result.stderr or "No response" + + # Return an object similar to requests.Response + class CurlResponse: + def __init__(self, status_code, text): + self.status_code = status_code + self.text = text + self.content = text.encode() + + return CurlResponse(status_code, response_text) + + except ( + subprocess.TimeoutExpired, + subprocess.CalledProcessError, + ValueError, + ) as e: + # Return a mock response for errors + class ErrorResponse: + def __init__(self, error_msg): + self.status_code = 0 + self.text = f"Error: {error_msg}" + self.content = self.text.encode() + + return ErrorResponse(str(e)) + + def test_invalid_character_model_names(self): + """Test that model names with invalid characters are properly rejected""" + + # Based on INVALID_CHARS = ";|&$`<>()[]{}\\\"'*?~#!" 
+ invalid_model_names = [
+ r"model;test",
+ r"model|test",
+ r"model&test",
+ r"model$test",
+ r"model`test`",
+ r"model<test>",
+ r"model(test)",
+ # r"model[test]", # request fails to send unencoded
+ r"model{test}",
+ r"model\test",
+ r'model"test"',
+ r"model'test'",
+ r"model*test",
+ # r"model?test", # request fails to send unencoded
+ r"model~test",
+ # r"model#test", # request fails to send unencoded
+ r"model!test",
+ ]
+
+ for invalid_name in invalid_model_names:
+ with self.subTest(model_name=invalid_name):
+ print(f"Testing invalid model name: {invalid_name}")
+
+ response = self._send_load_model_request(invalid_name)
+ print(
+ f"Response for '{invalid_name}': Status {response.status_code}, Text: {response.text[:200]}..."
+ )
+
+ # Should not get a successful 200 response
+ self.assertNotEqual(
+ 200,
+ response.status_code,
+ f"Invalid model name '{invalid_name}' should not get 200 OK response",
+ )
+
+ # Special case for curly braces - they get stripped and cause load failures prior to the validation check
+ if "{" in invalid_name or "}" in invalid_name:
+ self.assertIn(
+ "failed to load",
+ response.text,
+ f"Model with curly braces '{invalid_name}' should fail to load",
+ )
+ else:
+ # Normal case - should get character validation error
+ self.assertIn(
+ "Invalid stub name: contains invalid characters",
+ response.text,
+ f"invalid response for '{invalid_name}' should contain 'Invalid stub name: contains invalid characters'",
+ )
+
+ # Verify the model is not loaded/ready since it was rejected
+ try:
+ self.assertFalse(
+ self._client.is_model_ready(invalid_name),
+ f"Model '{invalid_name}' should not be ready after failed load attempt",
+ )
+ except Exception as e:
+ # If checking model readiness fails, that's also acceptable since the model name is invalid
+ print(
+ f"Note: Could not check model readiness for '{invalid_name}': {e}"
+ )
+
+ def test_valid_model_names(self):
+ """Test that valid model names work"""
+
+ valid_model_names = [
+ 
"TestModel123", + "model-with-hyphens", + "model_with_underscores", + ] + + for valid_name in valid_model_names: + with self.subTest(model_name=valid_name): + print(f"Testing valid model name: {valid_name}") + + response = self._send_load_model_request(valid_name) + print( + f"Response for valid '{valid_name}': Status {response.status_code}, Text: {response.text[:100]}..." + ) + + # Valid model names should be accepted and load successfully + self.assertEqual( + 200, + response.status_code, + f"Valid model name '{valid_name}' should get 200 OK response, got {response.status_code}. Response: {response.text}", + ) + + # Should not contain validation error message + self.assertNotIn( + "Invalid stub name: contains invalid characters", + response.text, + f"Valid model name '{valid_name}' should not contain validation error message", + ) + + # Verify the model is actually loaded by checking if it's ready + try: + self.assertTrue( + self._client.is_model_ready(valid_name), + f"Model '{valid_name}' should be ready after successful load", + ) + # Clean up - unload the model after testing + self._client.unload_model(valid_name) + except Exception as e: + self.fail(f"Failed to check if model '{valid_name}' is ready: {e}") + + if __name__ == "__main__": unittest.main() diff --git a/qa/L0_backend_python/model_control/test.sh b/qa/L0_backend_python/model_control/test.sh index e2c22f2685..f841222a1b 100755 --- a/qa/L0_backend_python/model_control/test.sh +++ b/qa/L0_backend_python/model_control/test.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -55,11 +55,24 @@ if [ $? -ne 0 ]; then echo -e "\n***\n*** model_control_test.py FAILED. 
\n***" RET=1 fi + +echo -e "\n***\n*** Running model ID validation test\n***" +SUBTEST="model_id_validation" +python3 -m pytest --junitxml=model_control.${SUBTEST}.report.xml model_control_test.py::ModelIDValidationTest >> ${CLIENT_LOG} 2>&1 + +if [ $? -ne 0 ]; then + echo -e "\n***\n*** model_id_validation_test.py FAILED. \n***" + RET=1 +fi + set -e kill_server if [ $RET -eq 1 ]; then + echo -e "\n***\n*** Server logs:\n***" + cat $SERVER_LOG + echo -e "\n***\n*** Client logs:\n***" cat $CLIENT_LOG echo -e "\n***\n*** model_control_test FAILED. \n***" else