
Commit 0a2e7f0

Add LM-Inline provider and unit tests
1 parent 86ce2de commit 0a2e7f0

File tree

16 files changed (+1391, -160 lines)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+module: llama_stack_provider_lmeval.inline
+config_class: llama_stack_provider_lmeval.config.LMEvalEvalProviderConfig
+pip_packages: ["lm-eval"]
+api_dependencies: ["inference"]
+optional_api_dependencies: []
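The "module" field above is what the stack resolves to locate the provider package. A minimal sketch of that lookup, assuming the package is installed; llama-stack's actual loader also handles config parsing and dependency injection, so this is illustration only:

# Illustrative sketch: resolving a provider spec's "module" field to its
# factory with importlib. Not llama-stack's real loader.
import importlib

spec = {
    "module": "llama_stack_provider_lmeval.inline",
    "config_class": "llama_stack_provider_lmeval.config.LMEvalEvalProviderConfig",
}

provider_pkg = importlib.import_module(spec["module"])
factory = provider_pkg.get_provider_impl  # defined later in this commit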

providers.d/remote/eval/trustyai_lmeval.yaml

Lines changed: 1 addition & 1 deletion

@@ -2,6 +2,6 @@ adapter:
   adapter_type: lmeval
   pip_packages: ["kubernetes"]
   config_class: llama_stack_provider_lmeval.config.LMEvalEvalProviderConfig
-  module: llama_stack_provider_lmeval
+  module: llama_stack_provider_lmeval.remote
 api_dependencies: ["inference"]
 optional_api_dependencies: []

run-inline.yaml

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+version: "2"
+image_name: trustyai-lmeval
+apis:
+- inference
+- eval
+providers:
+  inference:
+  - provider_id: vllm
+    provider_type: remote::vllm
+    config:
+      url: ${env.VLLM_URL:=http://localhost:8080/v1}
+      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
+      api_token: ${env.VLLM_API_TOKEN:=fake}
+      tls_verify: ${env.VLLM_TLS_VERIFY:=false}
+  eval:
+  - provider_id: trustyai_lmeval
+    provider_type: inline::trustyai_lmeval
+    config:
+      base_url: ${env.BASE_URL:=http://localhost:8321/v1}
+      use_k8s: ${env.USE_K8S:=false}
+# server:
+#   port: ${env.PORT:=8321}
+external_providers_dir: ./providers.d
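The ${env.VAR:=default} placeholders above resolve to the environment variable when it is set and fall back to the default otherwise. A minimal Python sketch of that substitution semantics, for illustration only (this is not llama-stack's loader):

# Mimics the ${env.VAR:=default} syntax used in run-inline.yaml:
# the environment variable wins when set, otherwise the default applies.
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):=([^}]*)\}")

def expand_env_defaults(value: str) -> str:
    """Replace ${env.VAR:=default} with os.environ["VAR"] if set, else default."""
    return _ENV_PATTERN.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

print(expand_env_defaults("${env.VLLM_URL:=http://localhost:8080/v1}"))
# -> value of $VLLM_URL if exported, otherwise http://localhost:8080/v1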
Lines changed: 0 additions & 54 deletions

@@ -1,54 +0,0 @@
-import logging
-
-from llama_stack.apis.datatypes import Api
-from llama_stack.providers.datatypes import ProviderSpec
-
-from .config import LMEvalEvalProviderConfig
-from .lmeval import LMEval
-from .provider import get_provider_spec
-
-# Set up logging
-logger = logging.getLogger(__name__)
-
-
-async def get_adapter_impl(
-    config: LMEvalEvalProviderConfig,
-    deps: dict[Api, ProviderSpec] | None = None,
-) -> LMEval:
-    """Get an LMEval implementation from the configuration.
-
-    Args:
-        config: LMEval configuration
-        deps: Optional dependencies for testing/injection
-
-    Returns:
-        Configured LMEval implementation
-
-    Raises:
-        Exception: If configuration is invalid
-    """
-    try:
-        if deps is None:
-            deps = {}
-
-        # Extract base_url from config if available
-        base_url = None
-        if hasattr(config, "model_args") and config.model_args:
-            for arg in config.model_args:
-                if arg.get("name") == "base_url":
-                    base_url = arg.get("value")
-                    logger.debug(f"Using base_url from config: {base_url}")
-                    break
-
-        return LMEval(config=config)
-    except Exception as e:
-        raise Exception(f"Failed to create LMEval implementation: {str(e)}") from e
-
-
-__all__ = [
-    # Factory methods
-    "get_adapter_impl",
-    # Configurations
-    "LMEval",
-    "get_provider_spec",
-]

src/llama_stack_provider_lmeval/config.py

Lines changed: 0 additions & 1 deletion

@@ -131,7 +131,6 @@ def __post_init__(self):
         if not isinstance(self.use_k8s, bool):
             raise LMEvalConfigError("use_k8s must be a boolean")

-
 __all__ = [
     "TLSConfig",
     "LMEvalBenchmarkConfig",
Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+import logging
+from typing import Optional
+
+from llama_stack.apis.datatypes import Api
+from llama_stack.providers.datatypes import ProviderSpec
+
+from llama_stack_provider_lmeval.config import LMEvalBenchmarkConfig
+from .lmeval import LMEvalInline
+
+logger = logging.getLogger(__name__)
+
+async def get_provider_impl(
+    config: LMEvalBenchmarkConfig,
+    deps: Optional[dict[Api, ProviderSpec]] = None,
+) -> LMEvalInline:
+    """Get an inline Eval implementation from the configuration.
+
+    Args:
+        config: LMEvalInlineBenchmarkConfig
+        deps: Optional[dict[Api, ProviderSpec]] = None
+
+    Returns:
+        Configured LMEval Inline implementation
+
+    Raises:
+        Exception: If configuration is invalid
+    """
+    try:
+        if deps is None:
+            deps = {}
+
+        # Extract base_url from config if available
+        base_url = None
+        if hasattr(config, "model_args") and config.model_args:
+            for arg in config.model_args:
+                if arg.get("name") == "base_url":
+                    base_url = arg.get("value")
+                    logger.debug(f"Using base_url from config: {base_url}")
+                    break
+
+        return LMEvalInline(config=config)
+    except Exception as e:
+        raise Exception(f"Failed to create LMEval implementation: {str(e)}") from e
+
+__all__ = [
+    "get_provider_impl",
+    "LMEvalInline",
+]
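A hypothetical caller of the new factory might look like the sketch below. It assumes this file is the package entry point referenced by module: llama_stack_provider_lmeval.inline in the spec above, and that LMEvalBenchmarkConfig accepts a model_args list of name/value dicts (the hasattr check suggests the field is optional); none of this is confirmed by the diff itself.

# Hypothetical usage sketch only; import path and constructor arguments
# are assumptions based on the provider spec and factory shown above.
import asyncio

from llama_stack_provider_lmeval.config import LMEvalBenchmarkConfig
from llama_stack_provider_lmeval.inline import get_provider_impl

config = LMEvalBenchmarkConfig(
    model_args=[{"name": "base_url", "value": "http://localhost:8321/v1"}],
)

# deps defaults to {} inside the factory when omitted
impl = asyncio.run(get_provider_impl(config))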
