|
1 | 1 | #!/usr/bin/env python3
|
2 | 2 |
|
3 |
| -# Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +# Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
4 | 4 | #
|
5 | 5 | # Redistribution and use in source and binary forms, with or without
|
6 | 6 | # modification, are permitted provided that the following conditions
|
|
26 | 26 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
27 | 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
28 | 28 |
|
| 29 | +import base64 |
| 30 | +import json |
29 | 31 | import os
|
| 32 | +import subprocess |
30 | 33 | import sys
|
31 | 34 |
|
32 | 35 | sys.path.append("../../common")
|
@@ -83,5 +86,241 @@ def test_model_reload(self):
|
83 | 86 | self.assertFalse(client.is_model_ready(ensemble_model_name))
|
84 | 87 |
|
85 | 88 |
|
| 89 | +class ModelIDValidationTest(unittest.TestCase): |
| 90 | + """ |
| 91 | + Test model ID validation for user-provided model names. |
| 92 | +
|
| 93 | + Verifies that model names containing dangerous characters are properly rejected. |
| 94 | + Uses raw HTTP requests via curl instead of the Triton client to test server-side |
| 95 | + validation without the Triton client encoding special characters. |
| 96 | + """ |
| 97 | + |
| 98 | + def setUp(self): |
| 99 | + self._shm_leak_detector = shm_util.ShmLeakDetector() |
| 100 | + self._client = httpclient.InferenceServerClient(f"{_tritonserver_ipaddr}:8000") |
| 101 | + self._triton_host = _tritonserver_ipaddr |
| 102 | + self._triton_port = 8000 |
| 103 | + |
| 104 | + # Check if curl is available |
| 105 | + try: |
| 106 | + subprocess.run(["curl", "--version"], capture_output=True, check=True) |
| 107 | + except (subprocess.CalledProcessError, FileNotFoundError): |
| 108 | + self.skipTest("curl command not available - required for raw HTTP testing") |
| 109 | + |
| 110 | + def _send_load_model_request(self, model_name): |
| 111 | + """Send HTTP request to load model for testing input validation using curl""" |
| 112 | + |
| 113 | + # Create simple Triton Python model code |
| 114 | + python_model_code = f"""import triton_python_backend_utils as pb_utils |
| 115 | +
|
| 116 | +class TritonPythonModel: |
| 117 | + def execute(self, requests): |
| 118 | + print('Hello world from model {model_name}') |
| 119 | + responses = [] |
| 120 | + for request in requests: |
| 121 | + # Simple identity function |
| 122 | + input_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT0") |
| 123 | + out_tensor = pb_utils.Tensor("OUTPUT0", input_tensor.as_numpy()) |
| 124 | + responses.append(pb_utils.InferenceResponse([out_tensor])) |
| 125 | + return responses""" |
| 126 | + |
| 127 | + # Base64 encode the Python code (as required by Triton server) |
| 128 | + python_code_b64 = base64.b64encode(python_model_code.encode("utf-8")).decode( |
| 129 | + "ascii" |
| 130 | + ) |
| 131 | + |
| 132 | + # Create simple config |
| 133 | + config = { |
| 134 | + "name": model_name, |
| 135 | + "backend": "python", |
| 136 | + "max_batch_size": 4, |
| 137 | + "input": [{"name": "INPUT0", "data_type": "TYPE_FP32", "dims": [-1]}], |
| 138 | + "output": [{"name": "OUTPUT0", "data_type": "TYPE_FP32", "dims": [-1]}], |
| 139 | + } |
| 140 | + |
| 141 | + payload = { |
| 142 | + "parameters": { |
| 143 | + "config": json.dumps(config), |
| 144 | + "file:/1/model.py": python_code_b64, |
| 145 | + } |
| 146 | + } |
| 147 | + |
| 148 | + url = f"http://{self._triton_host}:{self._triton_port}/v2/repository/models/{model_name}/load" |
| 149 | + |
| 150 | + # Convert payload to JSON string |
| 151 | + payload_json = json.dumps(payload) |
| 152 | + |
| 153 | + try: |
| 154 | + # Use curl to send the request |
| 155 | + curl_cmd = [ |
| 156 | + "curl", |
| 157 | + "-s", |
| 158 | + "-w", |
| 159 | + "\n%{http_code}", # Write HTTP status code on separate line |
| 160 | + "-X", |
| 161 | + "POST", |
| 162 | + "-H", |
| 163 | + "Content-Type: application/json", |
| 164 | + "-d", |
| 165 | + payload_json, |
| 166 | + url, |
| 167 | + ] |
| 168 | + |
| 169 | + result = subprocess.run( |
| 170 | + curl_cmd, capture_output=True, text=True, timeout=10 |
| 171 | + ) |
| 172 | + |
| 173 | + # Parse curl output - last line is status code, rest is response body |
| 174 | + output_lines = ( |
| 175 | + result.stdout.strip().split("\n") if result.stdout.strip() else [] |
| 176 | + ) |
| 177 | + if len(output_lines) >= 2: |
| 178 | + try: |
| 179 | + status_code = int(output_lines[-1]) |
| 180 | + response_text = "\n".join(output_lines[:-1]) |
| 181 | + except ValueError: |
| 182 | + status_code = 0 |
| 183 | + response_text = result.stdout or result.stderr or "Invalid response" |
| 184 | + elif len(output_lines) == 1 and output_lines[0].isdigit(): |
| 185 | + status_code = int(output_lines[0]) |
| 186 | + response_text = result.stderr or "No response body" |
| 187 | + else: |
| 188 | + status_code = 0 |
| 189 | + response_text = result.stdout or result.stderr or "No response" |
| 190 | + |
| 191 | + # Return an object similar to requests.Response |
| 192 | + class CurlResponse: |
| 193 | + def __init__(self, status_code, text): |
| 194 | + self.status_code = status_code |
| 195 | + self.text = text |
| 196 | + self.content = text.encode() |
| 197 | + |
| 198 | + return CurlResponse(status_code, response_text) |
| 199 | + |
| 200 | + except ( |
| 201 | + subprocess.TimeoutExpired, |
| 202 | + subprocess.CalledProcessError, |
| 203 | + ValueError, |
| 204 | + ) as e: |
| 205 | + # Return a mock response for errors |
| 206 | + class ErrorResponse: |
| 207 | + def __init__(self, error_msg): |
| 208 | + self.status_code = 0 |
| 209 | + self.text = f"Error: {error_msg}" |
| 210 | + self.content = self.text.encode() |
| 211 | + |
| 212 | + return ErrorResponse(str(e)) |
| 213 | + |
| 214 | + def test_invalid_character_model_names(self): |
| 215 | + """Test that model names with invalid characters are properly rejected""" |
| 216 | + |
| 217 | + # Based on INVALID_CHARS = ";|&$`<>()[]{}\\\"'*?~#!" |
| 218 | + invalid_model_names = [ |
| 219 | + r"model;test", |
| 220 | + r"model|test", |
| 221 | + r"model&test", |
| 222 | + r"model$test", |
| 223 | + r"model`test`", |
| 224 | + r"model<test>", |
| 225 | + r"model(test)", |
| 226 | + # r"model[test]", # request fails to send unencoded |
| 227 | + r"model{test}", |
| 228 | + r"model\test", |
| 229 | + r'model"test"', |
| 230 | + r"model'test'", |
| 231 | + r"model*test", |
| 232 | + # r"model?test", # request fails to send unencoded |
| 233 | + r"model~test", |
| 234 | + # r"model#test", # request fails to send unencoded |
| 235 | + r"model!test", |
| 236 | + ] |
| 237 | + |
| 238 | + for invalid_name in invalid_model_names: |
| 239 | + with self.subTest(model_name=invalid_name): |
| 240 | + print(f"Testing invalid model name: {invalid_name}") |
| 241 | + |
| 242 | + response = self._send_load_model_request(invalid_name) |
| 243 | + print( |
| 244 | + f"Response for '{invalid_name}': Status {response.status_code}, Text: {response.text[:200]}..." |
| 245 | + ) |
| 246 | + |
| 247 | + # Should not get a successful 200 response |
| 248 | + self.assertNotEqual( |
| 249 | + 200, |
| 250 | + response.status_code, |
| 251 | + f"Invalid model name '{invalid_name}' should not get 200 OK response", |
| 252 | + ) |
| 253 | + |
| 254 | + # Special case for curly braces - they get stripped and cause load failures prior to the validation check |
| 255 | + if "{" in invalid_name or "}" in invalid_name: |
| 256 | + self.assertIn( |
| 257 | + "failed to load", |
| 258 | + response.text, |
| 259 | + f"Model with curly braces '{invalid_name}' should fail to load", |
| 260 | + ) |
| 261 | + else: |
| 262 | + # Normal case - should get character validation error |
| 263 | + self.assertIn( |
| 264 | + "Invalid stub name: contains invalid characters", |
| 265 | + response.text, |
| 266 | + f"invalid response for '{invalid_name}' should contain 'Invalid stub name: contains invalid characters'", |
| 267 | + ) |
| 268 | + |
| 269 | + # Verify the model is not loaded/ready since it was rejected |
| 270 | + try: |
| 271 | + self.assertFalse( |
| 272 | + self._client.is_model_ready(invalid_name), |
| 273 | + f"Model '{invalid_name}' should not be ready after failed load attempt", |
| 274 | + ) |
| 275 | + except Exception as e: |
| 276 | + # If checking model readiness fails, that's also acceptable since the model name is invalid |
| 277 | + print( |
| 278 | + f"Note: Could not check model readiness for '{invalid_name}': {e}" |
| 279 | + ) |
| 280 | + |
| 281 | + def test_valid_model_names(self): |
| 282 | + """Test that valid model names work""" |
| 283 | + |
| 284 | + valid_model_names = [ |
| 285 | + "TestModel123", |
| 286 | + "model-with-hyphens", |
| 287 | + "model_with_underscores", |
| 288 | + ] |
| 289 | + |
| 290 | + for valid_name in valid_model_names: |
| 291 | + with self.subTest(model_name=valid_name): |
| 292 | + print(f"Testing valid model name: {valid_name}") |
| 293 | + |
| 294 | + response = self._send_load_model_request(valid_name) |
| 295 | + print( |
| 296 | + f"Response for valid '{valid_name}': Status {response.status_code}, Text: {response.text[:100]}..." |
| 297 | + ) |
| 298 | + |
| 299 | + # Valid model names should be accepted and load successfully |
| 300 | + self.assertEqual( |
| 301 | + 200, |
| 302 | + response.status_code, |
| 303 | + f"Valid model name '{valid_name}' should get 200 OK response, got {response.status_code}. Response: {response.text}", |
| 304 | + ) |
| 305 | + |
| 306 | + # Should not contain validation error message |
| 307 | + self.assertNotIn( |
| 308 | + "Invalid stub name: contains invalid characters", |
| 309 | + response.text, |
| 310 | + f"Valid model name '{valid_name}' should not contain validation error message", |
| 311 | + ) |
| 312 | + |
| 313 | + # Verify the model is actually loaded by checking if it's ready |
| 314 | + try: |
| 315 | + self.assertTrue( |
| 316 | + self._client.is_model_ready(valid_name), |
| 317 | + f"Model '{valid_name}' should be ready after successful load", |
| 318 | + ) |
| 319 | + # Clean up - unload the model after testing |
| 320 | + self._client.unload_model(valid_name) |
| 321 | + except Exception as e: |
| 322 | + self.fail(f"Failed to check if model '{valid_name}' is ready: {e}") |
| 323 | + |
| 324 | + |
86 | 325 | if __name__ == "__main__":
|
87 | 326 | unittest.main()
|
0 commit comments