Skip to content

Commit 60998f6

Browse files
authored
Merge pull request #3 from waggle-sensor/feature-nrp-llm
Add NRP LLM caption function
2 parents 9c0a6fb + c81c60e commit 60998f6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+32678
-10
lines changed

.env.example

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ UNALLOWED_NODES="W042,N001,V012,W015,W01C,W01E,W024,W026,W02C,W02D,W02E,W02F,W03
4545
LOG_LEVEL='INFO'
4646
MONITOR_DATA_STREAM_INTERVAL=60
4747
MONITOR_DATA_STREAM_QUERY_DELAY_MINUTE=5
48+
LLM_RUN_MODE=TRITON
49+
# Set the following only if LLM_RUN_MODE is set to NRP
50+
# NRP_API_KEY=
51+
# NRP_API_ENDPOINT=
52+
# NRP_LLM_MODEL=
4853

4954
#gradio-ui
5055
WEAVIATE_HOST=weaviate

benchmarking/benchmarks/INQUIRE/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def __init__(self):
3535
self._weaviate_grpc_port = os.environ.get("WEAVIATE_GRPC_PORT", "50051")
3636
self._collection_name = os.environ.get("COLLECTION_NAME", "INQUIRE")
3737

38+
# model provider parameters
39+
self._llm_model_provider = os.environ.get("LLM_MODEL_PROVIDER", "triton").lower()
40+
3841
# Triton parameters
3942
self._triton_host = os.environ.get("TRITON_HOST", "triton")
4043
self._triton_port = os.environ.get("TRITON_PORT", "8001")
@@ -97,3 +100,9 @@ def __init__(self):
97100
keywords: <keyword1>, <keyword2>, ...
98101
"""
99102
self.gemma3_prompt = os.environ.get("GEMMA3_PROMPT", default_prompt)
103+
104+
@staticmethod
105+
def is_nrp_key_set():
106+
"""Check if NRP API key is set."""
107+
if os.environ.get("NRP_API_KEY", "") == "":
108+
raise ValueError("NRP_API_KEY is not set")
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from imsearch_eval.adapters import TritonModelProvider, NRPModelProvider
2+
from imsearch_eval.framework import Config
3+
import os
4+
from PIL import Image
5+
from typing import Optional
6+
7+
class MixedModelProvider(NRPModelProvider):
    """
    Mixed model provider built on NRPModelProvider and TritonModelProvider.

    Embedding requests are always delegated to the wrapped
    TritonModelProvider. The ``config._llm_model_provider`` setting selects
    the LLM backend: with "triton" this provider's ``model_utils`` is replaced
    by the Triton provider's ``model_utils``; with "nrp" the inherited
    NRPModelProvider behaviour is kept and an NRP API key is required.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: str = "https://ellm.nrp-nautilus.io/v1",
        triton_model_provider: Optional[TritonModelProvider] = None,
        config: Optional[Config] = None,
        **client_kwargs,
    ):
        """
        Initialize Mixed model provider.

        Args:
            api_key: NRP API token. When omitted, the "NRP_API_KEY"
                environment variable is read at call time.
            base_url: Envoy gateway URL. Defaults to the NRP-managed LLM endpoint.
            triton_model_provider: Triton model provider used for embeddings
                (and for ``model_utils`` when the provider is "triton").
            config: Config object exposing ``_llm_model_provider`` and
                ``is_nrp_key_set()``.
            **client_kwargs: Optional extra arguments passed to the NRPModelProvider.

        Raises:
            ValueError: If ``config`` is missing, the configured provider is
                not "triton" or "nrp", the Triton provider is missing in
                "triton" mode, or no NRP API key is available in "nrp" mode.
        """
        # Resolve the key here rather than in the signature: a default of
        # os.environ.get(...) is evaluated once at import time and would miss
        # a key exported (or loaded from a .env file) afterwards.
        if api_key is None:
            api_key = os.environ.get("NRP_API_KEY")

        # Validate the configuration before constructing the parent client so
        # misconfiguration fails fast with a clear message.
        if config is None:
            raise ValueError("config is required to select the LLM model provider")

        llm_provider = config._llm_model_provider
        if llm_provider == "triton":
            if triton_model_provider is None:
                raise ValueError(
                    "triton_model_provider is required when the model provider is 'triton'"
                )
        elif llm_provider == "nrp":
            # An explicitly supplied key is sufficient; only fall back to the
            # environment check (which raises ValueError) when none was given.
            if not api_key:
                config.is_nrp_key_set()
        else:
            raise ValueError(f"Invalid model provider: {llm_provider} not supported")

        super().__init__(api_key=api_key, base_url=base_url, **client_kwargs)
        self.triton_model_provider = triton_model_provider
        self.config = config

        # Assigned after super().__init__() so it is not overwritten by
        # anything the parent initializer sets up.
        if llm_provider == "triton":
            self.model_utils = self.triton_model_provider.model_utils

    def get_embedding(
        self,
        text: str,
        image: Optional[Image.Image] = None,
        model_name: str = "clip"
    ):
        """
        Get embedding for text and/or image using triton model provider.

        Args:
            text: Text to embed
            image: Optional PIL Image to embed
            model_name: Name of the model to use ("clip", "colbert", "align")

        Returns:
            Whatever the Triton provider's ``get_embedding`` returns
            (documented upstream as a numpy embedding vector).

        Raises:
            ValueError: If no Triton model provider was supplied.
        """
        if self.triton_model_provider is None:
            raise ValueError("triton_model_provider is required to generate embeddings")
        return self.triton_model_provider.get_embedding(text, image, model_name)

benchmarking/benchmarks/INQUIRE/requirements.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# INQUIRE Benchmark Requirements
22
# Core benchmarking framework
3-
imsearch_eval[weaviate] @ git+https://github.com/waggle-sensor/imsearch_eval.git@0.1.0
4-
imsearch_eval[triton] @ git+https://github.com/waggle-sensor/imsearch_eval.git@0.1.0
5-
imsearch_eval[huggingface] @ git+https://github.com/waggle-sensor/imsearch_eval.git@0.1.0
3+
imsearch_eval[weaviate] @ git+https://github.com/waggle-sensor/imsearch_eval.git@0.1.1
4+
imsearch_eval[triton] @ git+https://github.com/waggle-sensor/imsearch_eval.git@0.1.1
5+
imsearch_eval[huggingface] @ git+https://github.com/waggle-sensor/imsearch_eval.git@0.1.1
6+
imsearch_eval[nrp] @ git+https://github.com/waggle-sensor/imsearch_eval.git@0.1.1
67

78
# Image processing
89
Pillow>=10.0.0

benchmarking/benchmarks/INQUIRE/results/v1/evaluate.ipynb renamed to benchmarking/benchmarks/INQUIRE/results/v01/evaluate.ipynb

File renamed without changes.

benchmarking/benchmarks/INQUIRE/results/v1/image_search_results.csv renamed to benchmarking/benchmarks/INQUIRE/results/v01/image_search_results.csv

File renamed without changes.

benchmarking/benchmarks/INQUIRE/results/v1/query_eval_metrics.csv renamed to benchmarking/benchmarks/INQUIRE/results/v01/query_eval_metrics.csv

File renamed without changes.

benchmarking/benchmarks/INQUIRE/results/v2/evaluate.ipynb renamed to benchmarking/benchmarks/INQUIRE/results/v02/evaluate.ipynb

File renamed without changes.

benchmarking/benchmarks/INQUIRE/results/v2/image_search_results.csv renamed to benchmarking/benchmarks/INQUIRE/results/v02/image_search_results.csv

File renamed without changes.

benchmarking/benchmarks/INQUIRE/results/v2/query_eval_metrics.csv renamed to benchmarking/benchmarks/INQUIRE/results/v02/query_eval_metrics.csv

File renamed without changes.

0 commit comments

Comments
 (0)