Skip to content

Commit 223ca21

Browse files
committed
cleanup request output
1 parent e672c28 commit 223ca21

File tree

1 file changed

+166
-0
lines changed

1 file changed

+166
-0
lines changed

qa/L0_backend_python/model_control/model_control_test.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828

2929
import os
30+
import subprocess
3031
import sys
3132

3233
sys.path.append("../../common")
@@ -83,5 +84,170 @@ def test_model_reload(self):
8384
self.assertFalse(client.is_model_ready(ensemble_model_name))
8485

8586

87+
class InputValidationTest(unittest.TestCase):
88+
"""
89+
Test input validation for user-provided inputs
90+
"""
91+
92+
def setUp(self):
93+
self._shm_leak_detector = shm_util.ShmLeakDetector()
94+
self._client = httpclient.InferenceServerClient(f"{_tritonserver_ipaddr}:8000")
95+
self._triton_host = _tritonserver_ipaddr
96+
self._triton_port = 8000
97+
98+
# Check if curl is available
99+
try:
100+
subprocess.run(["curl", "--version"], capture_output=True, check=True)
101+
except (subprocess.CalledProcessError, FileNotFoundError):
102+
self.skipTest("curl command not available - required for raw HTTP testing")
103+
104+
def _send_load_model_request(self, model_name):
105+
"""Send HTTP request to load model for testing input validation using curl"""
106+
payload = {
107+
"parameters": {
108+
"config": f'{{"name": "{model_name}", "backend": "python", "max_batch_size": 4}}',
109+
"file:/1/model.py": "print('Hello from Python Model')",
110+
}
111+
}
112+
113+
url = f"http://{self._triton_host}:{self._triton_port}/v2/repository/models/{model_name}/load"
114+
115+
# Convert payload to JSON string
116+
payload_json = json.dumps(payload)
117+
118+
try:
119+
# Use curl to send the request
120+
curl_cmd = [
121+
"curl",
122+
"-s",
123+
"-w",
124+
"\n%{http_code}",
125+
"-X",
126+
"POST",
127+
"-H",
128+
"Content-Type: application/json",
129+
"-d",
130+
payload_json,
131+
"--connect-timeout",
132+
"10",
133+
]
134+
135+
# Add the URL as a separate argument to avoid shell interpretation issues
136+
curl_cmd.append(url)
137+
138+
# Debug: print the exact URL being requested
139+
print(f"DEBUG: Curl URL: {url}")
140+
141+
result = subprocess.run(
142+
curl_cmd, capture_output=True, text=True, timeout=15
143+
)
144+
145+
# Parse curl output - last line is status code, rest is response body
146+
output_lines = (
147+
result.stdout.strip().split("\n") if result.stdout.strip() else []
148+
)
149+
if len(output_lines) >= 2:
150+
try:
151+
status_code = int(output_lines[-1])
152+
response_text = "\n".join(output_lines[:-1])
153+
except ValueError:
154+
status_code = 0
155+
response_text = result.stdout or result.stderr or "Invalid response"
156+
elif len(output_lines) == 1 and output_lines[0].isdigit():
157+
status_code = int(output_lines[0])
158+
response_text = result.stderr or "No response body"
159+
else:
160+
status_code = 0
161+
response_text = result.stdout or result.stderr or "No response"
162+
163+
# Return an object similar to requests.Response
164+
class CurlResponse:
165+
def __init__(self, status_code, text):
166+
self.status_code = status_code
167+
self.text = text
168+
self.content = text.encode()
169+
170+
return CurlResponse(status_code, response_text)
171+
172+
except (
173+
subprocess.TimeoutExpired,
174+
subprocess.CalledProcessError,
175+
ValueError,
176+
) as e:
177+
# Return a mock response for errors
178+
class ErrorResponse:
179+
def __init__(self, error_msg):
180+
self.status_code = 0
181+
self.text = f"Error: {error_msg}"
182+
self.content = self.text.encode()
183+
184+
return ErrorResponse(str(e))
185+
186+
def test_invalid_character_model_names(self):
187+
"""Test that model names with invalid characters are properly rejected"""
188+
189+
# Model names with various invalid characters that should be rejected
190+
invalid_model_names = [
191+
"model$(test)",
192+
"model\{test\}",
193+
"model`test`",
194+
"model;test",
195+
"model|test",
196+
"model&test",
197+
"model'test'",
198+
"model*test",
199+
"model!test",
200+
]
201+
202+
for invalid_name in invalid_model_names:
203+
with self.subTest(model_name=invalid_name):
204+
print(f"Testing invalid model name: {invalid_name}")
205+
206+
response = self._send_load_model_request(invalid_name)
207+
print(
208+
f"Response for '{invalid_name}': Status {response.status_code}, Text: {response.text[:200]}..."
209+
)
210+
211+
# Should not get a successful 200 response
212+
self.assertNotEqual(
213+
200,
214+
response.status_code,
215+
f"Invalid model name '{invalid_name}' should not get 200 OK response",
216+
)
217+
218+
self.assertIn(
219+
"Invalid stub name: contains invalid characters",
220+
response.text,
221+
f"invalid response for '{invalid_name}' should contain 'Invalid stub name: contains invalid characters'",
222+
)
223+
224+
def test_valid_model_names(self):
225+
"""Test that valid model names work"""
226+
227+
valid_model_names = [
228+
"TestModel123",
229+
"model-with-hyphens",
230+
"model_with_underscores",
231+
]
232+
233+
for valid_name in valid_model_names:
234+
with self.subTest(model_name=valid_name):
235+
print(f"Testing valid model name: {valid_name}")
236+
237+
response = self._send_load_model_request(valid_name)
238+
print(
239+
f"Response for valid '{valid_name}': Status {response.status_code}, Text: {response.text[:100]}..."
240+
)
241+
242+
# Valid names might still fail for other reasons (model doesn't exist, etc.)
243+
# but they should not be rejected due to character validation
244+
# We just check it's not a validation error
245+
self.assertNotIn(
246+
"Invalid stub name: contains invalid characters",
247+
response.text,
248+
f"valid response for '{valid_name}' should not contain 'Invalid stub name: contains invalid characters'",
249+
)
250+
251+
86252
if __name__ == "__main__":
87253
unittest.main()

0 commit comments

Comments
 (0)