Commit 2e632a9

feat: Static key authentication for OpenAI frontend (#8374)
1 parent e320c17 commit 2e632a9

8 files changed: +802 -12 lines changed
python/openai/README.md

Lines changed: 52 additions & 1 deletion
@@ -637,4 +637,55 @@ Example output:
 function name: get_n_day_weather_forecast
 function arguments: {"city": "Dallas", "state": "TX", "unit": "fahrenheit", num_days: 1}
 tool calling result: The weather in Dallas, Texas is 85 degrees fahrenheit in next 1 days.
-```
+```
+
+## Limit Endpoint Access
+
+The OpenAI-compatible server supports restricting access to specific API endpoints through authentication headers. This feature allows you to protect sensitive endpoints while keeping others publicly accessible.
+
+### Configuration
+
+Use the `--openai-restricted-api` command-line argument to configure endpoint restrictions:
+
+```
+--openai-restricted-api <API_1>,<API_2>,... <restricted-key> <restricted-value>
+```
+
+- **`API`**: A comma-separated list of APIs to be included in this group. Note that currently a given API is not allowed to be included in multiple groups. The following protocols / APIs are recognized:
+  - **inference**: Chat completions and text completions endpoints
+    - `POST /v1/chat/completions`
+    - `POST /v1/completions`
+  - **model-repository**: Model listing and information endpoints
+    - `GET /v1/models`
+    - `GET /v1/models/{model_name}`
+  - **metrics**: Server metrics endpoint
+    - `GET /metrics`
+  - **health**: Health check endpoint
+    - `GET /health/ready`
+
+- **`restricted-key`**: The HTTP request header to be checked when a request is received.
+- **`restricted-value`**: The header value required to access the specified protocols.
+
+### Examples
+
+#### Restrict Inference API Endpoints Only
+```bash
+--openai-restricted-api "inference api-key my-secret-key"
+```
+
+Clients must include the header:
+```bash
+curl -H "api-key: my-secret-key" \
+     -X POST http://localhost:9000/v1/chat/completions \
+     -d '{"model": "my-model", "messages": [{"role": "user", "content": "Hello"}]}'
+```
+
+#### Restrict Multiple API Endpoints
+```bash
+# Different authentication for different APIs
+--openai-restricted-api "inference user-key user-secret" \
+--openai-restricted-api "model-repository admin-key admin-secret"
+
+# Multiple APIs in single argument with shared authentication
+--openai-restricted-api "inference,model-repository shared-key shared-secret"
+```
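The curl call in the README addition above can also be issued from Python. The following is a minimal sketch using only the standard library; the helper names `build_chat_request` and `send_request` are hypothetical, and the host, port, header name, and secret are taken from the README example rather than from any real deployment. The `Content-Type` header is an addition not present in the curl example.

```python
import json
import urllib.request


def build_chat_request(base_url, header_key, header_value, model, prompt):
    """Build (without sending) a chat completion request that carries
    the header configured via --openai-restricted-api."""
    body = json.dumps(
        {"model": model, "messages": [{"role": "user", "content": prompt}]}
    ).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",
        data=body,
        # The restricted header travels alongside the usual content type.
        headers={"Content-Type": "application/json", header_key: header_value},
        method="POST",
    )


def send_request(request):
    """Send the request and decode the JSON response body.

    urllib raises urllib.error.HTTPError for a 401, which is what the
    middleware returns when the header is missing or wrong.
    """
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))
```

Calling `send_request(build_chat_request("http://localhost:9000", "api-key", "my-secret-key", "my-model", "Hello"))` would then correspond to the curl example, assuming a server is listening on that address.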
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from fastapi import Request
+from fastapi.responses import JSONResponse
+from starlette.middleware.base import BaseHTTPMiddleware
+
+# Mapping of API to their corresponding HTTP endpoints
+ENDPOINT_MAPPING = {
+    "inference": ["POST /v1/chat/completions", "POST /v1/completions"],
+    "model-repository": ["GET /v1/models"],
+    "metrics": ["GET /metrics"],
+    "health": ["GET /health/ready"],
+}
+
+
+class RestrictedFeatures:
+    """
+    Manages API endpoint restrictions and their authentication requirements.
+
+    This class parses command-line arguments for restricted API configurations
+    and provides methods to check if specific APIs are restricted
+    and what authentication is required.
+    """
+
+    def __init__(self, args: list[str]):
+        """
+        Initialize the RestrictedFeatures with command-line arguments.
+
+        Args:
+            args: List of --openai-restricted-api argument strings
+                  (e.g., [["inference", "infer-key", "infer-value"],
+                          ["model-repository", "model-key", "model-value"]])
+        """
+        self._restrictions = {}
+        self.ParseRestrictedFeatureOption(args)
+
+    def ParseRestrictedFeatureOption(self, args):
+        """
+        Parse command-line arguments to extract API restrictions.
+
+        Args:
+            args: List of restriction configuration strings
+
+        Raises:
+            ValueError: If unknown API is specified or duplicate API configs are found
+        """
+        for apis, key, value in args:
+            api_list = apis.split(",")
+            for api in api_list:
+                # Validate that the API is valid
+                if api not in ENDPOINT_MAPPING:
+                    raise ValueError(
+                        f"Unknown API '{api}'. Available APIs: {list(ENDPOINT_MAPPING.keys())}"
+                    )
+
+                # Check for duplicate APIs across different arguments
+                if self.IsRestricted(api):
+                    raise ValueError(
+                        f"restricted api '{api}' can not be specified in multiple config groups"
+                    )
+
+                self.Insert(api, (key, value))
+
+    def RestrictionDict(self) -> dict[str, tuple[str, str]]:
+        """
+        Get a copy of the restrictions dictionary.
+
+        Returns:
+            dict: Copy of the restrictions mapping API names to (header_key, header_value) tuples
+        """
+        return self._restrictions.copy()
+
+    def Insert(self, api: str, restriction: tuple[str, str]):
+        """
+        Add a restriction for a specific API.
+
+        Args:
+            api: The API name (e.g., "inference", "model-repository")
+            restriction: Tuple of (header_key, header_value) for authentication
+        """
+        self._restrictions[api] = restriction
+
+    def IsRestricted(self, api: str) -> bool:
+        """
+        Check if a specific API is restricted.
+
+        Args:
+            api: The API name to check
+
+        Returns:
+            bool: True if the API is restricted, False otherwise
+        """
+        return api in self._restrictions
+
+
+class APIRestrictionMiddleware(BaseHTTPMiddleware):
+    """
+    Middleware to restrict API endpoint access based on allowed APIs configuration.
+
+    This middleware intercepts HTTP requests and checks if they match any restricted
+    API endpoints. If a request matches a restricted endpoint, it validates the
+    authentication headers before allowing the request to proceed.
+
+    Similar to Triton Server's endpoint access control feature.
+    """
+
+    def __init__(self, app, restricted_apis: RestrictedFeatures):
+        """
+        Initialize the API restriction middleware.
+
+        Args:
+            app: The FastAPI application instance
+            restricted_apis: RestrictedFeatures instance containing the restriction configuration
+        """
+        super().__init__(app)
+        self.restricted_apis = restricted_apis
+
+    def _get_auth_header(self, request: Request) -> tuple[str, str] | None:
+        request_method = request.method
+        request_path = request.url.path
+
+        # Check each restricted API to see if the request matches
+        for (
+            restricted_api,
+            auth_spec,
+        ) in self.restricted_apis.RestrictionDict().items():
+            # Check each endpoint in the API
+            for restricted_endpoint in ENDPOINT_MAPPING[restricted_api]:
+                restricted_method, restricted_path = restricted_endpoint.split(" ")
+
+                # Match both HTTP method and path prefix
+                if request_method == restricted_method and request_path.startswith(
+                    restricted_path
+                ):
+                    return auth_spec
+        return None
+
+    async def dispatch(self, request: Request, call_next):
+        """
+        Main middleware dispatch method that processes each incoming request.
+
+        Args:
+            request: The incoming HTTP request
+            call_next: The next middleware/handler in the chain
+
+        Returns:
+            Response: Either the next handler's response or a 401 authentication error
+        """
+        # Check if the request matches any restricted patterns
+        auth_header = self._get_auth_header(request)
+
+        # If request not restricted, proceed with the request
+        if not auth_header:
+            return await call_next(request)
+
+        # Check authentication for the matching restricted endpoint
+        auth_result = self._check_authentication(request, auth_header)
+        if auth_result["valid"]:
+            # Authentication passed, allow request to proceed
+            return await call_next(request)
+        else:
+            # Authentication failed, return 401 error
+            return JSONResponse(
+                status_code=401,
+                content={
+                    "error": {
+                        "message": auth_result["message"],
+                        "type": "authentication_error",
+                        "code": "invalid_auth",
+                    }
+                },
+            )
+
+    def _check_authentication(self, request: Request, auth_header: tuple[str, str]):
+        """
+        Check if the request contains valid authentication headers.
+
+        Args:
+            request: The incoming HTTP request
+            auth_header: Tuple of (expected_header_key, expected_header_value)
+
+        Returns:
+            dict: {"valid": bool, "message": str} - Authentication result and error message if invalid
+        """
+        expected_key, expected_value = auth_header
+
+        # Get the actual header value from the request
+        actual_value = request.headers.get(expected_key)
+
+        # Validate the header value matches the expected value
+        if not actual_value or actual_value != expected_value:
+            return {
+                "valid": False,
+                "message": f"This API is restricted, expecting header '{expected_key}' with valid value",
+            }
+
+        return {"valid": True}
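To make the matching rule in the middleware easier to follow, here is a framework-free sketch of the same decision logic: parse the restriction groups, find the header required for a given method and path, and compare it with the request headers. The function names `parse_restrictions`, `required_header`, and `is_authorized` are hypothetical and not part of the commit; headers are modelled as a plain dict that the sketch lower-cases itself, standing in for the case-insensitive lookup a real request object provides.

```python
ENDPOINT_MAPPING = {
    "inference": ["POST /v1/chat/completions", "POST /v1/completions"],
    "model-repository": ["GET /v1/models"],
    "metrics": ["GET /metrics"],
    "health": ["GET /health/ready"],
}


def parse_restrictions(args):
    """Turn [[apis, key, value], ...] into {api: (key, value)}."""
    restrictions = {}
    for apis, key, value in args:
        for api in apis.split(","):
            if api not in ENDPOINT_MAPPING:
                raise ValueError(f"Unknown API '{api}'")
            if api in restrictions:
                raise ValueError(f"API '{api}' appears in multiple groups")
            restrictions[api] = (key, value)
    return restrictions


def required_header(restrictions, method, path):
    """Return the (key, value) pair guarding this request, or None."""
    for api, auth_spec in restrictions.items():
        for endpoint in ENDPOINT_MAPPING[api]:
            restricted_method, restricted_path = endpoint.split(" ")
            # Exact method match plus path *prefix* match, so
            # "GET /v1/models" also covers "GET /v1/models/{model_name}".
            if method == restricted_method and path.startswith(restricted_path):
                return auth_spec
    return None


def is_authorized(restrictions, method, path, headers):
    """True if the request is unrestricted or carries the expected header."""
    auth_spec = required_header(restrictions, method, path)
    if auth_spec is None:
        return True
    key, value = auth_spec
    normalized = {name.lower(): val for name, val in headers.items()}
    actual = normalized.get(key.lower())
    # An empty or missing header value is always rejected.
    return bool(actual) and actual == value


restrictions = parse_restrictions(
    [["inference,model-repository", "shared-key", "shared-secret"]]
)
print(is_authorized(restrictions, "POST", "/v1/chat/completions", {}))
print(is_authorized(restrictions, "GET", "/v1/models/my-model",
                    {"Shared-Key": "shared-secret"}))
print(is_authorized(restrictions, "GET", "/health/ready", {}))
```

Because the path comparison is a prefix match, any path that begins with a restricted path is guarded by the same header, which is how the single `GET /v1/models` entry also covers the per-model endpoint.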

python/openai/openai_frontend/frontend/fastapi_frontend.py

Lines changed: 22 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -30,6 +30,10 @@
 from engine.triton_engine import TritonLLMEngine
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from frontend.fastapi.middleware.api_restriction import (
+    APIRestrictionMiddleware,
+    RestrictedFeatures,
+)
 from frontend.fastapi.routers import chat, completions, models, observability
 from frontend.frontend import OpenAIFrontend
 
@@ -41,10 +45,17 @@ def __init__(
         host: str = "localhost",
         port: int = 8000,
         log_level: str = "info",
+        restricted_apis: list = None,
     ):
         self.host: str = host
         self.port: int = port
         self.log_level: str = log_level
+        if restricted_apis:
+            self.restricted_apis: RestrictedFeatures = RestrictedFeatures(
+                restricted_apis
+            )
+        else:
+            self.restricted_apis: RestrictedFeatures = None
         self.stopped: bool = False
 
         self.app = self._create_app()
@@ -89,6 +100,8 @@ def _create_app(self):
 
         # NOTE: For debugging purposes, should generally be restricted or removed
         self._add_cors_middleware(app)
+        if self.restricted_apis != None:
+            self._add_api_restriction_middleware(app)
 
         return app
 
@@ -107,3 +120,11 @@ def _add_cors_middleware(self, app: FastAPI):
             allow_methods=["*"],
             allow_headers=["*"],
         )
+
+    def _add_api_restriction_middleware(self, app: FastAPI):
+        app.add_middleware(
+            APIRestrictionMiddleware, restricted_apis=self.restricted_apis
+        )
+        print(
+            f"[INFO] API restrictions enabled. Restricted API endpoints: {self.restricted_apis.RestrictionDict()}"
+        )
