
Commit f837c47

NVIDIARerank add http_client parameter to pass custom clients (#17832)
1 parent a59dce8 commit f837c47

File tree

8 files changed: +190 -89 lines changed

llama-index-integrations/postprocessor/llama-index-postprocessor-nvidia-rerank/README.md

Lines changed: 17 additions & 0 deletions
@@ -103,3 +103,20 @@ nodes = parser.get_nodes_from_documents(documents)
 # rerank
 rerank.postprocess_nodes(nodes, query_str=query)
 ```
+
+### Custom HTTP Client
+
+If you need more control over HTTP settings (e.g., timeouts, proxies, retries), you can pass your own `httpx.Client` instance to the `NVIDIARerank` initializer:
+
+```python
+import httpx
+from llama_index.postprocessor.nvidia_rerank import NVIDIARerank
+
+# Create a custom httpx client with a 10-second timeout
+custom_client = httpx.Client(timeout=10.0)
+
+# Pass the custom client to the reranker
+rerank = NVIDIARerank(
+    base_url="http://localhost:1976/v1", http_client=custom_client
+)
+```
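The README example above configures only a timeout. As a hedged sketch, the other settings it mentions (proxies, retries) can live on the same client via a custom transport; the proxy URL and retry count below are illustrative placeholders, and the `proxy`/`retries` arguments to `httpx.HTTPTransport` assume a reasonably recent httpx release:

```python
import httpx
from llama_index.postprocessor.nvidia_rerank import NVIDIARerank

# Illustrative values only; substitute your real proxy URL and limits.
custom_client = httpx.Client(
    # Overall 10-second limit, but allow only 5 seconds to connect
    timeout=httpx.Timeout(10.0, connect=5.0),
    # Route traffic through a proxy and retry failed connection attempts
    transport=httpx.HTTPTransport(
        retries=3,
        proxy="http://proxy.example.com:8080",
    ),
)

rerank = NVIDIARerank(
    base_url="http://localhost:1976/v1", http_client=custom_client
)
```

Note that `retries` here covers connection errors only, not failed responses; response-level retry policies would need a wrapper or a third-party transport.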

llama-index-integrations/postprocessor/llama-index-postprocessor-nvidia-rerank/llama_index/postprocessor/nvidia_rerank/base.py

Lines changed: 99 additions & 26 deletions
@@ -1,6 +1,7 @@
 from typing import Any, List, Optional, Generator, Literal
 import os
 from urllib.parse import urlparse, urlunparse
+import httpx
 
 from llama_index.core.bridge.pydantic import Field, PrivateAttr, ConfigDict
 from llama_index.core.callbacks import CBEventType, EventPayload
@@ -11,10 +12,16 @@
 )
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
-import requests
 import warnings
 from llama_index.core.base.llms.generic_utils import get_from_param_or_env
 
+from .utils import (
+    RANKING_MODEL_TABLE,
+    BASE_URL,
+    DEFAULT_MODEL,
+    Model,
+    determine_model,
+)
 from .utils import (
     RANKING_MODEL_TABLE,
     BASE_URL,
@@ -56,13 +63,15 @@ class NVIDIARerank(BaseNodePostprocessor):
     _mode: str = PrivateAttr("nvidia")
     _is_hosted: bool = PrivateAttr(True)
     base_url: Optional[str] = None
+    _http_client: Optional[httpx.Client] = PrivateAttr(None)
 
     def __init__(
         self,
         model: Optional[str] = None,
         nvidia_api_key: Optional[str] = None,
         api_key: Optional[str] = None,
         base_url: Optional[str] = os.getenv("NVIDIA_BASE_URL", BASE_URL),
+        http_client: Optional[httpx.Client] = None,
         **kwargs: Any,
     ):
         """
@@ -75,6 +84,7 @@ def __init__(
             nvidia_api_key (str, optional): The NVIDIA API key. Defaults to None.
             api_key (str, optional): The API key. Defaults to None.
             base_url (str, optional): The base URL of the on-premises NIM. Defaults to None.
+            http_client (httpx.Client, optional): Custom HTTP client for making requests.
             truncate (str): "NONE", "END", truncate input text if it exceeds
                 the model's context length. Default is model dependent and
                 is likely to raise an error if an input is too long.
@@ -87,6 +97,8 @@ def __init__(
         model = model or DEFAULT_MODEL
         super().__init__(model=model, **kwargs)
 
+        self._is_hosted = base_url in KNOWN_URLS
+        self.base_url = base_url
         self._is_hosted = base_url in KNOWN_URLS
         self.base_url = base_url
         self._api_key = get_from_param_or_env(
@@ -95,12 +107,11 @@ def __init__(
             "NVIDIA_API_KEY",
             "NO_API_KEY_PROVIDED",
         )
-
         if self._is_hosted:  # hosted on API Catalog (build.nvidia.com)
             if (not self._api_key) or (self._api_key == "NO_API_KEY_PROVIDED"):
                 raise ValueError("An API key is required for hosted NIM.")
         else:  # not hosted
-            self.base_url = self._validate_url(base_url)
+            self.base_url = self._validate_url(self.base_url)
 
         self.model = model
         if not self.model:
@@ -110,10 +121,9 @@ def __init__(
             self.__get_default_model()
 
         if not self.model.startswith("nvdev/"):
-            # allow internal models
-            # TODO: add test case for this
             self._validate_model(self.model)  ## validate model
-        self.base_url = base_url
+
+        self._http_client = http_client
 
     def __get_default_model(self):
         """Set default model."""
@@ -136,24 +146,30 @@ def __get_default_model(self):
         else:
             self.model = DEFAULT_MODEL
 
+    @property
+    def normalized_base_url(self) -> str:
+        """Return the normalized base URL (without trailing slashes)."""
+        return self.base_url.rstrip("/")
+
+    def _get_headers(self, auth_required: bool = False) -> dict:
+        """Return default headers for HTTP requests.
+
+        If auth_required is True or the client is hosted, includes an Authorization header.
+        """
+        headers = {"Accept": "application/json"}
+        if auth_required or self._is_hosted:
+            headers["Authorization"] = f"Bearer {self._api_key}"
+        return headers
+
     def _get_models(self) -> List[Model]:
-        session = requests.Session()
-        self.base_url = self.base_url.rstrip("/") + "/"
-        if self._is_hosted:
-            _headers = {
-                "Authorization": f"Bearer {self._api_key}",
-                "Accept": "application/json",
-            }
-        else:
-            _headers = {
-                "Accept": "application/json",
-            }
+        client = self.client
+        _headers = self._get_headers(auth_required=self._is_hosted)
         url = (
             "https://integrate.api.nvidia.com/v1/models"
             if self._is_hosted
-            else self.base_url.rstrip("/") + "/models"
+            else self.normalized_base_url + "/models"
         )
-        response = session.get(url, headers=_headers)
+        response = client.get(url, headers=_headers)
         response.raise_for_status()
 
         assert (
@@ -181,6 +197,18 @@ def _get_models(self) -> List[Model]:
             ]
         else:
             return RANKING_MODEL_TABLE
+        # TODO: hosted now has a model listing, need to merge known and listed models
+        # TODO: parse model config for local models
+        if not self._is_hosted:
+            return [
+                Model(
+                    id=model["id"],
+                    base_model=getattr(model, "params", {}).get("root", None),
+                )
+                for model in response.json()["data"]
+            ]
+        else:
+            return RANKING_MODEL_TABLE
 
     def _validate_url(self, base_url):
         """
@@ -190,10 +218,37 @@ def _validate_url(self, base_url):
         emit a warning. old documentation told users to pass in the full
         inference url, which is incorrect and prevents model listing from working.
         normalize base_url to end in /v1.
+        validate the base_url.
+        if the base_url is not a url, raise an error
+        if the base_url does not end in /v1, e.g. /embeddings
+        emit a warning. old documentation told users to pass in the full
+        inference url, which is incorrect and prevents model listing from working.
+        normalize base_url to end in /v1.
         """
         if base_url is not None:
             parsed = urlparse(base_url)
 
+            # Ensure scheme and netloc (domain name) are present
+            if not (parsed.scheme and parsed.netloc):
+                expected_format = "Expected format is: http://host:port"
+                raise ValueError(
+                    f"Invalid base_url format. {expected_format} Got: {base_url}"
+                )
+
+            normalized_path = parsed.path.rstrip("/")
+            if not normalized_path.endswith("/v1"):
+                warnings.warn(
+                    f"{base_url} does not end in /v1, you may "
+                    "have inference and listing issues"
+                )
+                normalized_path += "/v1"
+
+            base_url = urlunparse(
+                (parsed.scheme, parsed.netloc, normalized_path, None, None, None)
+            )
+        if base_url is not None:
+            parsed = urlparse(base_url)
+
             # Ensure scheme and netloc (domain name) are present
             if not (parsed.scheme and parsed.netloc):
                 expected_format = "Expected format is: http://host:port"
@@ -228,6 +283,15 @@ def _validate_model(self, model_name: str) -> None:
         model = determine_model(model_name)
         available_model_ids = [model.id for model in self.available_models]
 
+        if not model:
+            if self._is_hosted:
+                warnings.warn(f"Unable to determine validity of {model_name}")
+            else:
+                if model_name not in available_model_ids:
+                    raise ValueError(f"No locally hosted {model_name} was found.")
+        model = determine_model(model_name)
+        available_model_ids = [model.id for model in self.available_models]
+
         if not model:
             if self._is_hosted:
                 warnings.warn(f"Unable to determine validity of {model_name}")
@@ -238,16 +302,29 @@ def _validate_model(self, model_name: str) -> None:
         if model and model.endpoint:
             self.base_url = model.endpoint
 
+        if model and model.endpoint:
+            self.base_url = model.endpoint
+
     @property
     def available_models(self) -> List[Model]:
         """Get available models."""
         # all available models are in the map
         ids = RANKING_MODEL_TABLE.keys()
+        ids = RANKING_MODEL_TABLE.keys()
         if not self._is_hosted:
             return self._get_models()
         else:
             return [Model(id=id) for id in ids]
 
+    @property
+    def client(self) -> httpx.Client:
+        """
+        Lazy initialization of the HTTP client.
+        """
+        if self._http_client is None:
+            self._http_client = httpx.Client()
+        return self._http_client
+
     @classmethod
     def class_name(cls) -> str:
         return "NVIDIARerank"
@@ -273,12 +350,8 @@ def _postprocess_nodes(
         if len(nodes) == 0:
             return []
 
-        session = requests.Session()
-
-        _headers = {
-            "Authorization": f"Bearer {self._api_key}",
-            "Accept": "application/json",
-        }
+        client = self.client
+        _headers = self._get_headers(auth_required=True)
 
         # TODO: replace with itertools.batched in python 3.12
         def batched(ls: list, size: int) -> Generator[List[NodeWithScore], None, None]:
@@ -305,7 +378,7 @@ def batched(ls: list, size: int) -> Generator[List[NodeWithScore], None, None]:
                     for n in batch
                 ],
             }
-            response = session.post(self.base_url, headers=_headers, json=payloads)
+            response = client.post(self.base_url, headers=_headers, json=payloads)
             response.raise_for_status()
             # expected response format:
             # {
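For orientation: the change above replaces ad-hoc `requests.Session()` objects with a single `httpx.Client` that is created lazily and reused. A minimal standalone sketch of that pattern, with an illustrative class name (this is not the real `NVIDIARerank`):

```python
from typing import Optional

import httpx


class LazyClientSketch:
    """Illustrative reduction of the lazy-client pattern in the diff above."""

    def __init__(self, http_client: Optional[httpx.Client] = None) -> None:
        # None means "build a default client on first use";
        # a caller-supplied client is stored and used as-is.
        self._http_client = http_client

    @property
    def client(self) -> httpx.Client:
        # Created at most once, then reused, so connection pooling and any
        # caller-supplied settings apply to every subsequent request.
        if self._http_client is None:
            self._http_client = httpx.Client()
        return self._http_client


sketch = LazyClientSketch()
assert sketch.client is sketch.client  # the same instance is reused
```

Because `http_client=` merely pre-populates the private attribute, the property returns the caller's client untouched, which is what lets the README's timeout and proxy settings take effect on every request.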

llama-index-integrations/postprocessor/llama-index-postprocessor-nvidia-rerank/pyproject.toml

Lines changed: 5 additions & 1 deletion
@@ -30,7 +30,7 @@ license = "MIT"
 name = "llama-index-postprocessor-nvidia-rerank"
 packages = [{include = "llama_index/"}]
 readme = "README.md"
-version = "0.4.1"
+version = "0.4.2"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
@@ -56,6 +56,10 @@ types-redis = "4.5.5.0"
 types-requests = "2.28.11.8"  # TODO: unpin when mypy>0.991
 types-setuptools = "67.1.0.0"
 
+[tool.poetry.group.test-integration.dependencies]
+responses = "^0.25.6"
+respx = {extras = ["pytest"], version = "^0.22.0"}
+
 [tool.poetry.group.test_integration.dependencies]
 pytest-httpx = "*"
 requests-mock = "^1.12.1"
llama-index-integrations/postprocessor/llama-index-postprocessor-nvidia-rerank/tests/test_api_key.py

Lines changed: 4 additions & 10 deletions
@@ -1,23 +1,17 @@
 import os
 
 import pytest
-
+import respx
 from llama_index.postprocessor.nvidia_rerank import NVIDIARerank as Interface
 from llama_index.core.schema import NodeWithScore, Document
 
 from typing import Any
-from requests_mock import Mocker
 
 
 @pytest.fixture()
-def mock_local_models(requests_mock: Mocker) -> None:
-    requests_mock.get(
-        "https://test_url/v1/models",
-        json={
-            "data": [
-                {"id": "model1"},
-            ]
-        },
+def mock_local_models(respx_mock: respx.MockRouter) -> None:
+    respx_mock.get("https://test_url/v1/models").respond(
+        json={"data": [{"id": "model1"}]}
     )
 
 
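The tests migrate from `requests-mock` to respx because the implementation now issues requests through httpx, which `requests-mock` cannot intercept. A self-contained sketch of the same mocking in respx's context-manager style, outside any test suite (URL and payload mirror the fixture above):

```python
import httpx
import respx

# respx intercepts httpx traffic; register a route, then call it.
with respx.mock:
    respx.get("https://test_url/v1/models").respond(
        json={"data": [{"id": "model1"}]}
    )
    response = httpx.get("https://test_url/v1/models")
    assert response.json() == {"data": [{"id": "model1"}]}
```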

llama-index-integrations/postprocessor/llama-index-postprocessor-nvidia-rerank/tests/test_available_models.py

Lines changed: 3 additions & 3 deletions
@@ -1,12 +1,12 @@
 import pytest
 
 from llama_index.postprocessor.nvidia_rerank import NVIDIARerank
-from requests_mock import Mocker
+import respx
 
 
 @pytest.fixture(autouse=True)
-def mock_local_models(requests_mock: Mocker) -> None:
-    requests_mock.get(
+def mock_local_models(respx_mock: respx.MockRouter) -> None:
+    respx_mock.get(
         "https://test_url/v1/models",
         json={
             "data": [
