|
| 1 | +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# Redistribution and use in source and binary forms, with or without |
| 4 | +# modification, are permitted provided that the following conditions |
| 5 | +# are met: |
| 6 | +# * Redistributions of source code must retain the above copyright |
| 7 | +# notice, this list of conditions and the following disclaimer. |
| 8 | +# * Redistributions in binary form must reproduce the above copyright |
| 9 | +# notice, this list of conditions and the following disclaimer in the |
| 10 | +# documentation and/or other materials provided with the distribution. |
| 11 | +# * Neither the name of NVIDIA CORPORATION nor the names of its |
| 12 | +# contributors may be used to endorse or promote products derived |
| 13 | +# from this software without specific prior written permission. |
| 14 | +# |
| 15 | +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY |
| 16 | +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 17 | +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR |
| 18 | +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR |
| 19 | +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, |
| 20 | +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, |
| 21 | +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR |
| 22 | +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY |
| 23 | +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| 24 | +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 25 | +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 26 | + |
import hmac

from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
| 30 | + |
# Restrictable API groups and the HTTP endpoints each one covers.
# Every endpoint is written as "<METHOD> <path>".
ENDPOINT_MAPPING = {
    "inference": [
        "POST /v1/chat/completions",
        "POST /v1/completions",
    ],
    "model-repository": [
        "GET /v1/models",
    ],
    "metrics": [
        "GET /metrics",
    ],
    "health": [
        "GET /health/ready",
    ],
}
| 38 | + |
| 39 | + |
| 40 | +class RestrictedFeatures: |
| 41 | + """ |
| 42 | + Manages API endpoint restrictions and their authentication requirements. |
| 43 | +
|
| 44 | + This class parses command-line arguments for restricted API configurations |
| 45 | + and provides methods to check if specific APIs are restricted |
| 46 | + and what authentication is required. |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__(self, args: list[str]): |
| 50 | + """ |
| 51 | + Initialize the RestrictedFeatures with command-line arguments. |
| 52 | +
|
| 53 | + Args: |
| 54 | + args: List of --openai-restricted-api argument strings |
| 55 | + (e.g., [["inference", "infer-key", "infer-value"], |
| 56 | + ["model-repository", "model-key", "model-value"]]) |
| 57 | + """ |
| 58 | + self._restrictions = {} |
| 59 | + self.ParseRestrictedFeatureOption(args) |
| 60 | + |
| 61 | + def ParseRestrictedFeatureOption(self, args): |
| 62 | + """ |
| 63 | + Parse command-line arguments to extract API restrictions. |
| 64 | +
|
| 65 | + Args: |
| 66 | + args: List of restriction configuration strings |
| 67 | +
|
| 68 | + Raises: |
| 69 | + ValueError: If unknown API is specified or duplicate API configs are found |
| 70 | + """ |
| 71 | + for apis, key, value in args: |
| 72 | + api_list = apis.split(",") |
| 73 | + for api in api_list: |
| 74 | + # Validate that the API is valid |
| 75 | + if api not in ENDPOINT_MAPPING: |
| 76 | + raise ValueError( |
| 77 | + f"Unknown API '{api}'. Available APIs: {list(ENDPOINT_MAPPING.keys())}" |
| 78 | + ) |
| 79 | + |
| 80 | + # Check for duplicate APIs across different arguments |
| 81 | + if self.IsRestricted(api): |
| 82 | + raise ValueError( |
| 83 | + f"restricted api '{api}' can not be specified in multiple config groups" |
| 84 | + ) |
| 85 | + |
| 86 | + self.Insert(api, (key, value)) |
| 87 | + |
| 88 | + def RestrictionDict(self) -> dict[str, tuple[str, str]]: |
| 89 | + """ |
| 90 | + Get a copy of the restrictions dictionary. |
| 91 | +
|
| 92 | + Returns: |
| 93 | + dict: Copy of the restrictions mapping API names to (header_key, header_value) tuples |
| 94 | + """ |
| 95 | + return self._restrictions.copy() |
| 96 | + |
| 97 | + def Insert(self, api: str, restriction: tuple[str, str]): |
| 98 | + """ |
| 99 | + Add a restriction for a specific API. |
| 100 | +
|
| 101 | + Args: |
| 102 | + api: The API name (e.g., "inference", "model-repository") |
| 103 | + restriction: Tuple of (header_key, header_value) for authentication |
| 104 | + """ |
| 105 | + self._restrictions[api] = restriction |
| 106 | + |
| 107 | + def IsRestricted(self, api: str) -> bool: |
| 108 | + """ |
| 109 | + Check if a specific API is restricted. |
| 110 | +
|
| 111 | + Args: |
| 112 | + api: The API name to check |
| 113 | +
|
| 114 | + Returns: |
| 115 | + bool: True if the API is restricted, False otherwise |
| 116 | + """ |
| 117 | + return api in self._restrictions |
| 118 | + |
| 119 | + |
class APIRestrictionMiddleware(BaseHTTPMiddleware):
    """
    Middleware to restrict API endpoint access based on allowed APIs configuration.

    This middleware intercepts HTTP requests and checks if they match any restricted
    API endpoints. If a request matches a restricted endpoint, it validates the
    authentication headers before allowing the request to proceed.

    Similar to Triton Server's endpoint access control feature.
    """

    def __init__(self, app, restricted_apis: RestrictedFeatures):
        """
        Initialize the API restriction middleware.

        Args:
            app: The FastAPI application instance
            restricted_apis: RestrictedFeatures instance containing the restriction configuration
        """
        super().__init__(app)
        self.restricted_apis = restricted_apis

    def _get_auth_header(self, request: Request) -> tuple[str, str] | None:
        """
        Find the authentication requirement that applies to a request.

        Args:
            request: The incoming HTTP request

        Returns:
            The (header_key, header_value) tuple of the first restricted API
            with an endpoint matching the request's method and path prefix,
            or None if the request does not hit a restricted endpoint.
        """
        request_method = request.method
        request_path = request.url.path

        # Check each restricted API to see if the request matches
        restrictions = self.restricted_apis.RestrictionDict()
        for restricted_api, auth_spec in restrictions.items():
            # Check each endpoint in the API
            for restricted_endpoint in ENDPOINT_MAPPING[restricted_api]:
                # Entries are "<METHOD> <path>"; split once so the path is
                # always taken whole.
                restricted_method, restricted_path = restricted_endpoint.split(" ", 1)

                # Match the HTTP method exactly and the path by prefix, so
                # sub-paths of a restricted endpoint are restricted too.
                if request_method == restricted_method and request_path.startswith(
                    restricted_path
                ):
                    return auth_spec
        return None

    async def dispatch(self, request: Request, call_next):
        """
        Main middleware dispatch method that processes each incoming request.

        Args:
            request: The incoming HTTP request
            call_next: The next middleware/handler in the chain

        Returns:
            Response: Either the next handler's response or a 401 authentication error
        """
        # Check if the request matches any restricted patterns
        auth_header = self._get_auth_header(request)

        # If request not restricted, proceed with the request. Only None means
        # "unrestricted"; any other value is checked, so a malformed
        # restriction fails closed instead of being waved through.
        if auth_header is None:
            return await call_next(request)

        # Check authentication for the matching restricted endpoint
        auth_result = self._check_authentication(request, auth_header)
        if auth_result["valid"]:
            # Authentication passed, allow request to proceed
            return await call_next(request)

        # Authentication failed, return 401 error
        return JSONResponse(
            status_code=401,
            content={
                "error": {
                    "message": auth_result["message"],
                    "type": "authentication_error",
                    "code": "invalid_auth",
                }
            },
        )

    def _check_authentication(
        self, request: Request, auth_header: tuple[str, str]
    ) -> dict:
        """
        Check if the request contains valid authentication headers.

        Args:
            request: The incoming HTTP request
            auth_header: Tuple of (expected_header_key, expected_header_value)

        Returns:
            dict: {"valid": True} on success, or
            {"valid": False, "message": str} describing the failure.
        """
        expected_key, expected_value = auth_header

        # Get the actual header value from the request
        actual_value = request.headers.get(expected_key)

        # A missing or empty header is always rejected. Otherwise compare in
        # constant time so response timing does not reveal how much of the
        # secret a guess got right. Both sides are encoded to bytes because
        # compare_digest only accepts str arguments when they are pure ASCII.
        authorized = bool(actual_value) and hmac.compare_digest(
            actual_value.encode("utf-8"), expected_value.encode("utf-8")
        )
        if not authorized:
            return {
                "valid": False,
                "message": f"This API is restricted, expecting header '{expected_key}' with valid value",
            }

        return {"valid": True}
0 commit comments