1
1
from typing import Any , List , Optional , Generator , Literal
2
2
import os
3
3
from urllib .parse import urlparse , urlunparse
4
+ import httpx
4
5
5
6
from llama_index .core .bridge .pydantic import Field , PrivateAttr , ConfigDict
6
7
from llama_index .core .callbacks import CBEventType , EventPayload
11
12
)
12
13
from llama_index .core .postprocessor .types import BaseNodePostprocessor
13
14
from llama_index .core .schema import MetadataMode , NodeWithScore , QueryBundle
14
- import requests
15
15
import warnings
16
16
from llama_index .core .base .llms .generic_utils import get_from_param_or_env
17
17
18
+ from .utils import (
19
+ RANKING_MODEL_TABLE ,
20
+ BASE_URL ,
21
+ DEFAULT_MODEL ,
22
+ Model ,
23
+ determine_model ,
24
+ )
18
25
from .utils import (
19
26
RANKING_MODEL_TABLE ,
20
27
BASE_URL ,
@@ -56,13 +63,15 @@ class NVIDIARerank(BaseNodePostprocessor):
56
63
_mode : str = PrivateAttr ("nvidia" )
57
64
_is_hosted : bool = PrivateAttr (True )
58
65
base_url : Optional [str ] = None
66
+ _http_client : Optional [httpx .Client ] = PrivateAttr (None )
59
67
60
68
def __init__ (
61
69
self ,
62
70
model : Optional [str ] = None ,
63
71
nvidia_api_key : Optional [str ] = None ,
64
72
api_key : Optional [str ] = None ,
65
73
base_url : Optional [str ] = os .getenv ("NVIDIA_BASE_URL" , BASE_URL ),
74
+ http_client : Optional [httpx .Client ] = None ,
66
75
** kwargs : Any ,
67
76
):
68
77
"""
@@ -75,6 +84,7 @@ def __init__(
75
84
nvidia_api_key (str, optional): The NVIDIA API key. Defaults to None.
76
85
api_key (str, optional): The API key. Defaults to None.
77
86
base_url (str, optional): The base URL of the on-premises NIM. Defaults to None.
87
+ http_client (httpx.Client, optional): Custom HTTP client for making requests.
78
88
truncate (str): "NONE", "END", truncate input text if it exceeds
79
89
the model's context length. Default is model dependent and
80
90
is likely to raise an error if an input is too long.
@@ -87,6 +97,8 @@ def __init__(
87
97
model = model or DEFAULT_MODEL
88
98
super ().__init__ (model = model , ** kwargs )
89
99
100
+ self ._is_hosted = base_url in KNOWN_URLS
101
+ self .base_url = base_url
90
102
self ._is_hosted = base_url in KNOWN_URLS
91
103
self .base_url = base_url
92
104
self ._api_key = get_from_param_or_env (
@@ -95,12 +107,11 @@ def __init__(
95
107
"NVIDIA_API_KEY" ,
96
108
"NO_API_KEY_PROVIDED" ,
97
109
)
98
-
99
110
if self ._is_hosted : # hosted on API Catalog (build.nvidia.com)
100
111
if (not self ._api_key ) or (self ._api_key == "NO_API_KEY_PROVIDED" ):
101
112
raise ValueError ("An API key is required for hosted NIM." )
102
113
else : # not hosted
103
- self .base_url = self ._validate_url (base_url )
114
+ self .base_url = self ._validate_url (self . base_url )
104
115
105
116
self .model = model
106
117
if not self .model :
@@ -110,10 +121,9 @@ def __init__(
110
121
self .__get_default_model ()
111
122
112
123
if not self .model .startswith ("nvdev/" ):
113
- # allow internal models
114
- # TODO: add test case for this
115
124
self ._validate_model (self .model ) ## validate model
116
- self .base_url = base_url
125
+
126
+ self ._http_client = http_client
117
127
118
128
def __get_default_model (self ):
119
129
"""Set default model."""
@@ -136,24 +146,30 @@ def __get_default_model(self):
136
146
else :
137
147
self .model = DEFAULT_MODEL
138
148
149
@property
def normalized_base_url(self) -> str:
    """Return ``base_url`` with all trailing ``/`` characters removed.

    The result is meant to have a path such as ``/models`` appended to it
    without producing a double slash.

    Returns:
        str: ``self.base_url`` stripped of trailing slashes.

    Raises:
        ValueError: If ``base_url`` is ``None``. The attribute is declared
            ``Optional[str]``, so fail with an explicit message instead of
            an ``AttributeError`` on ``None.rstrip``.
    """
    base_url = self.base_url
    if base_url is None:
        raise ValueError("base_url is not set; cannot build a request URL.")
    return base_url.rstrip("/")
153
+
154
def _get_headers(self, auth_required: bool = False) -> dict:
    """Build the header dictionary sent with an HTTP request.

    Args:
        auth_required (bool): Force inclusion of the ``Authorization``
            header even when the instance is not marked as hosted.

    Returns:
        dict: Always contains ``Accept: application/json``. Also contains
        ``Authorization: Bearer <api key>`` when ``auth_required`` is true
        or ``self._is_hosted`` is true.
    """
    # Unauthenticated case first: only the Accept header is needed.
    if not (auth_required or self._is_hosted):
        return {"Accept": "application/json"}

    return {
        "Accept": "application/json",
        "Authorization": f"Bearer {self._api_key}",
    }
163
+
139
164
def _get_models (self ) -> List [Model ]:
140
- session = requests .Session ()
141
- self .base_url = self .base_url .rstrip ("/" ) + "/"
142
- if self ._is_hosted :
143
- _headers = {
144
- "Authorization" : f"Bearer { self ._api_key } " ,
145
- "Accept" : "application/json" ,
146
- }
147
- else :
148
- _headers = {
149
- "Accept" : "application/json" ,
150
- }
165
+ client = self .client
166
+ _headers = self ._get_headers (auth_required = self ._is_hosted )
151
167
url = (
152
168
"https://integrate.api.nvidia.com/v1/models"
153
169
if self ._is_hosted
154
- else self .base_url . rstrip ( "/" ) + "/models"
170
+ else self .normalized_base_url + "/models"
155
171
)
156
- response = session .get (url , headers = _headers )
172
+ response = client .get (url , headers = _headers )
157
173
response .raise_for_status ()
158
174
159
175
assert (
@@ -181,6 +197,18 @@ def _get_models(self) -> List[Model]:
181
197
]
182
198
else :
183
199
return RANKING_MODEL_TABLE
200
+ # TODO: hosted now has a model listing, need to merge known and listed models
201
+ # TODO: parse model config for local models
202
+ if not self ._is_hosted :
203
+ return [
204
+ Model (
205
+ id = model ["id" ],
206
+ base_model = getattr (model , "params" , {}).get ("root" , None ),
207
+ )
208
+ for model in response .json ()["data" ]
209
+ ]
210
+ else :
211
+ return RANKING_MODEL_TABLE
184
212
185
213
def _validate_url (self , base_url ):
186
214
"""
@@ -190,10 +218,37 @@ def _validate_url(self, base_url):
190
218
emit a warning. old documentation told users to pass in the full
191
219
inference url, which is incorrect and prevents model listing from working.
192
220
normalize base_url to end in /v1.
221
+ validate the base_url.
222
+ if the base_url is not a url, raise an error
223
+ if the base_url does not end in /v1, e.g. /embeddings
224
+ emit a warning. old documentation told users to pass in the full
225
+ inference url, which is incorrect and prevents model listing from working.
226
+ normalize base_url to end in /v1.
193
227
"""
194
228
if base_url is not None :
195
229
parsed = urlparse (base_url )
196
230
231
+ # Ensure scheme and netloc (domain name) are present
232
+ if not (parsed .scheme and parsed .netloc ):
233
+ expected_format = "Expected format is: http://host:port"
234
+ raise ValueError (
235
+ f"Invalid base_url format. { expected_format } Got: { base_url } "
236
+ )
237
+
238
+ normalized_path = parsed .path .rstrip ("/" )
239
+ if not normalized_path .endswith ("/v1" ):
240
+ warnings .warn (
241
+ f"{ base_url } does not end in /v1, you may "
242
+ "have inference and listing issues"
243
+ )
244
+ normalized_path += "/v1"
245
+
246
+ base_url = urlunparse (
247
+ (parsed .scheme , parsed .netloc , normalized_path , None , None , None )
248
+ )
249
+ if base_url is not None :
250
+ parsed = urlparse (base_url )
251
+
197
252
# Ensure scheme and netloc (domain name) are present
198
253
if not (parsed .scheme and parsed .netloc ):
199
254
expected_format = "Expected format is: http://host:port"
@@ -228,6 +283,15 @@ def _validate_model(self, model_name: str) -> None:
228
283
model = determine_model (model_name )
229
284
available_model_ids = [model .id for model in self .available_models ]
230
285
286
+ if not model :
287
+ if self ._is_hosted :
288
+ warnings .warn (f"Unable to determine validity of { model_name } " )
289
+ else :
290
+ if model_name not in available_model_ids :
291
+ raise ValueError (f"No locally hosted { model_name } was found." )
292
+ model = determine_model (model_name )
293
+ available_model_ids = [model .id for model in self .available_models ]
294
+
231
295
if not model :
232
296
if self ._is_hosted :
233
297
warnings .warn (f"Unable to determine validity of { model_name } " )
@@ -238,16 +302,29 @@ def _validate_model(self, model_name: str) -> None:
238
302
if model and model .endpoint :
239
303
self .base_url = model .endpoint
240
304
305
+ if model and model .endpoint :
306
+ self .base_url = model .endpoint
307
+
241
308
@property
def available_models(self) -> List[Model]:
    """Get available models.

    Returns:
        List[Model]: For a non-hosted instance, whatever
        ``self._get_models()`` returns. For a hosted instance, one
        ``Model`` per key of ``RANKING_MODEL_TABLE``.
    """
    if not self._is_hosted:
        return self._get_models()

    # Hosted: every available model is a key of the static table. The
    # table is only consulted on this path (the duplicated up-front
    # ``ids = ...`` assignments were dead work for non-hosted instances).
    return [Model(id=model_id) for model_id in RANKING_MODEL_TABLE.keys()]
250
318
319
@property
def client(self) -> httpx.Client:
    """Return the HTTP client, creating it on first access.

    If ``self._http_client`` is already set it is returned unchanged.
    Otherwise a default ``httpx.Client()`` is created, stored on
    ``self._http_client`` and returned, so later accesses reuse it.
    """
    existing = self._http_client
    if existing is not None:
        return existing

    # Nothing cached yet: build a default client and remember it.
    self._http_client = httpx.Client()
    return self._http_client
327
+
251
328
@classmethod
def class_name(cls) -> str:
    """Return the fixed string ``"NVIDIARerank"`` identifying this class."""
    name = "NVIDIARerank"
    return name
@@ -273,12 +350,8 @@ def _postprocess_nodes(
273
350
if len (nodes ) == 0 :
274
351
return []
275
352
276
- session = requests .Session ()
277
-
278
- _headers = {
279
- "Authorization" : f"Bearer { self ._api_key } " ,
280
- "Accept" : "application/json" ,
281
- }
353
+ client = self .client
354
+ _headers = self ._get_headers (auth_required = True )
282
355
283
356
# TODO: replace with itertools.batched in python 3.12
284
357
def batched (ls : list , size : int ) -> Generator [List [NodeWithScore ], None , None ]:
@@ -305,7 +378,7 @@ def batched(ls: list, size: int) -> Generator[List[NodeWithScore], None, None]:
305
378
for n in batch
306
379
],
307
380
}
308
- response = session .post (self .base_url , headers = _headers , json = payloads )
381
+ response = client .post (self .base_url , headers = _headers , json = payloads )
309
382
response .raise_for_status ()
310
383
# expected response format:
311
384
# {
0 commit comments