101 changes: 101 additions & 0 deletions ee/vellum_ee/workflows/display/tests/test_base_workflow_display.py
@@ -843,3 +843,104 @@ def test_serialize_module__chat_message_prompt_block_validation_error():
    assert any(
        "validation error" in msg.lower() for msg in error_messages
    ), f"Expected validation error in error messages, got: {error_messages}"


def test_parse_ml_models__skips_invalid_models_and_returns_valid_ones(caplog):
    """
    Tests that _parse_ml_models skips models with validation errors and returns valid ones.
    """

    # GIVEN a workflow class
    class ExampleWorkflow(BaseWorkflow):
        pass

    # AND a mix of valid and invalid ml_models data
    ml_models_raw = [
        {"name": "gpt-4", "hosted_by": "OPENAI"},
        {"name": "invalid-model", "hosted_by": "INVALID_PROVIDER"},
        {"name": "claude-3", "hosted_by": "ANTHROPIC"},
        {"name": "missing-hosted-by"},
        {"name": "gemini", "hosted_by": "GOOGLE"},
    ]

    # WHEN creating a display with the ml_models
    display = BaseWorkflowDisplay[ExampleWorkflow](ml_models=ml_models_raw)

    # THEN only the valid models should be parsed
    assert len(display._ml_models) == 3

    # AND the valid models should have the correct names
    model_names = [model.name for model in display._ml_models]
    assert "gpt-4" in model_names
    assert "claude-3" in model_names
    assert "gemini" in model_names

    # AND the invalid models should not be included
    assert "invalid-model" not in model_names
    assert "missing-hosted-by" not in model_names

    # AND warnings should be logged for the invalid models
    warning_messages = [record.message for record in caplog.records if record.levelname == "WARNING"]
    assert len(warning_messages) == 2
    assert any("invalid-model" in msg for msg in warning_messages)
    assert any("missing-hosted-by" in msg for msg in warning_messages)


def test_parse_ml_models__returns_empty_list_when_all_models_invalid(caplog):
    """
    Tests that _parse_ml_models returns an empty list when all models fail validation.
    """

    # GIVEN a workflow class
    class ExampleWorkflow(BaseWorkflow):
        pass

    # AND ml_models data where all models are invalid
    ml_models_raw = [
        {"name": "invalid-model", "hosted_by": "INVALID_PROVIDER"},
        {"name": "missing-hosted-by"},
        {"hosted_by": "OPENAI"},
    ]

    # WHEN creating a display with the ml_models
    display = BaseWorkflowDisplay[ExampleWorkflow](ml_models=ml_models_raw)

    # THEN no models should be parsed
    assert len(display._ml_models) == 0

    # AND warnings should be logged for all invalid models
    warning_messages = [record.message for record in caplog.records if record.levelname == "WARNING"]
    assert len(warning_messages) == 3


def test_parse_ml_models__returns_all_models_when_all_valid(caplog):
    """
    Tests that _parse_ml_models returns all models when all are valid.
    """

    # GIVEN a workflow class
    class ExampleWorkflow(BaseWorkflow):
        pass

    # AND ml_models data where all models are valid
    ml_models_raw = [
        {"name": "gpt-4", "hosted_by": "OPENAI"},
        {"name": "claude-3", "hosted_by": "ANTHROPIC"},
        {"name": "gemini", "hosted_by": "GOOGLE"},
    ]

    # WHEN creating a display with the ml_models
    display = BaseWorkflowDisplay[ExampleWorkflow](ml_models=ml_models_raw)

    # THEN all models should be parsed
    assert len(display._ml_models) == 3

    # AND the models should have the correct names and hosted_by values
    model_data = [(model.name, model.hosted_by.value) for model in display._ml_models]
    assert ("gpt-4", "OPENAI") in model_data
    assert ("claude-3", "ANTHROPIC") in model_data
    assert ("gemini", "GOOGLE") in model_data

    # AND no warnings should be logged
    warning_messages = [record.message for record in caplog.records if record.levelname == "WARNING"]
    assert len(warning_messages) == 0
14 changes: 12 additions & 2 deletions ee/vellum_ee/workflows/display/workflows/base_workflow_display.py
@@ -212,8 +212,18 @@ def __init__(
        self._ml_models = self._parse_ml_models(ml_models) if ml_models else []

    def _parse_ml_models(self, ml_models_raw: list) -> List[MLModel]:
        """Parse raw list of dicts into MLModel instances using pydantic deserialization."""
        return [MLModel.model_validate(item) for item in ml_models_raw]
        """Parse raw list of dicts into MLModel instances using pydantic deserialization.

        Models that fail validation are skipped and a warning is logged.
        """
        parsed_models: List[MLModel] = []
        for item in ml_models_raw:
            try:
                parsed_models.append(MLModel.model_validate(item))
            except Exception as e:
                model_name = item.get("name") if isinstance(item, dict) else None
                logger.warning(f"Skipping ML model '{model_name}' due to validation error: {type(e).__name__}")
Comment on lines +221 to +225

P2: Limit exception handling to validation errors

Catching Exception here will silently skip models on any unexpected error raised inside MLModel.model_validate, including programming errors or runtime issues unrelated to input validation. That can mask real regressions (e.g., a bug in validation logic) by downgrading them to warnings and dropping models, which makes failures harder to detect. Since the intent is only to ignore validation failures, consider catching the specific Pydantic validation exception(s) instead.
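A minimal sketch of the narrower handling suggested above, assuming MLModel is a Pydantic v2 model so that MLModel.model_validate raises pydantic.ValidationError on bad input; MLModel and logger refer to the names already used in this module, and the snippet is illustrative rather than the committed change:

```python
from typing import List

from pydantic import ValidationError

# Assumes MLModel (a Pydantic model) and logger already exist in this module,
# as in the diff above; this method would sit on BaseWorkflowDisplay.


def _parse_ml_models(self, ml_models_raw: list) -> List[MLModel]:
    """Parse raw dicts into MLModel instances, skipping only validation failures."""
    parsed_models: List[MLModel] = []
    for item in ml_models_raw:
        try:
            parsed_models.append(MLModel.model_validate(item))
        except ValidationError as e:
            # Only genuine input-validation problems are downgraded to a warning;
            # any other exception propagates and surfaces as a real failure.
            model_name = item.get("name") if isinstance(item, dict) else None
            logger.warning(f"Skipping ML model '{model_name}' due to validation error: {type(e).__name__}")
    return parsed_models
```

This keeps the skip-and-warn behavior the new tests exercise while letting unrelated bugs inside model_validate surface instead of being logged and dropped.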

        return parsed_models

    def serialize(self) -> JsonObject:
        try: