Skip to content

Commit f01e740

Browse files
authored
feat: Fix model argument extraction in LMEvalCRBuilder (#53)
* feat: Fix model argument extraction in LMEvalCRBuilder * fix: Linting
1 parent 72cd231 commit f01e740

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

src/llama_stack_provider_lmeval/lmeval.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,21 @@ def _create_model_args(
358358
]
359359

360360
# Add model name from benchmark config, falling back to eval_candidate.model
361+
model_name = None
361362
if hasattr(benchmark_config, "model") and benchmark_config.model:
362-
model_args.append(ModelArg(name="model", value=benchmark_config.model))
363+
model_name = benchmark_config.model
364+
elif (
365+
hasattr(benchmark_config, "eval_candidate")
366+
and benchmark_config.eval_candidate
367+
):
368+
if (
369+
hasattr(benchmark_config.eval_candidate, "model")
370+
and benchmark_config.eval_candidate.model
371+
):
372+
model_name = benchmark_config.eval_candidate.model
373+
374+
if model_name:
375+
model_args.append(ModelArg(name="model", value=model_name))
363376

364377
# Add custom model args from benchmark config, avoiding duplicate keys
365378
if hasattr(benchmark_config, "model_args") and benchmark_config.model_args:

tests/test_lmeval.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,47 @@ def setUp(self):
605605
self.stored_benchmark = MagicMock()
606606
self.stored_benchmark.metadata = {}
607607

608+
def test_create_cr_with_model_in_eval_candidate(self):
609+
"""Test that model is correctly extracted from eval_candidate.model."""
610+
config = LMEvalEvalProviderConfig(
611+
namespace=self.namespace,
612+
service_account=self.service_account,
613+
)
614+
self.builder._config = config
615+
616+
# Create a mocked benchmark config; its direct model attribute is cleared to None below
617+
benchmark_config = MagicMock()
618+
619+
# Create eval_candidate as a plain object (not a mock) carrying type, model and sampling_params
620+
class EvalCandidate:
621+
def __init__(self):
622+
self.type = "model"
623+
self.model = "eval-candidate-model"
624+
self.sampling_params = {}
625+
626+
eval_candidate = EvalCandidate()
627+
benchmark_config.eval_candidate = eval_candidate
628+
benchmark_config.env_vars = []
629+
benchmark_config.metadata = {}
630+
# A MagicMock auto-creates a truthy child mock for any attribute read, so model is set to None explicitly
631+
benchmark_config.model = None
632+
# With model falsy, the only model name available to the builder is eval_candidate.model
633+
634+
cr = self.builder.create_cr(
635+
benchmark_id="lmeval::mmlu",
636+
task_config=benchmark_config,
637+
base_url="http://my-model-url",
638+
limit="10",
639+
stored_benchmark=self.stored_benchmark,
640+
)
641+
642+
model_args = cr.get("spec", {}).get("modelArgs", [])
643+
model_arg = next((arg for arg in model_args if arg.get("name") == "model"), None)
644+
645+
self.assertIsNotNone(model_arg, "Model argument should be present in modelArgs")
646+
self.assertEqual(model_arg.get("value"), "eval-candidate-model",
647+
"Model value should be extracted from eval_candidate.model")
648+
608649
@patch("src.llama_stack_provider_lmeval.lmeval.logger")
609650
def test_create_cr_without_tls(self, mock_logger):
610651
"""Creating CR without TLS configuration."""

0 commit comments

Comments
 (0)