Commit 8b55772

[APO-2934] Make _parse_ml_models resilient to validation errors (#3702)
* [APO-2934] Make _parse_ml_models resilient to validation errors
* [APO-2934] Improve error handling and add warning log verification

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: vargas@vellum.ai <vargas@vellum.ai>
1 parent 430eb9a commit 8b55772

2 files changed: +113 −2 lines changed
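For context on why the change was needed: the old implementation parsed the whole list in a single comprehension, so one malformed entry raised and dropped every model. A minimal sketch of that failure mode, assuming pydantic v2; `HostedBy` and `DemoMLModel` here are hypothetical stand-ins, not the actual Vellum types:

```python
# Hypothetical stand-in types; the real MLModel lives in the Vellum codebase.
from enum import Enum

from pydantic import BaseModel, ValidationError


class HostedBy(str, Enum):
    OPENAI = "OPENAI"
    ANTHROPIC = "ANTHROPIC"
    GOOGLE = "GOOGLE"


class DemoMLModel(BaseModel):
    name: str
    hosted_by: HostedBy


raw = [
    {"name": "gpt-4", "hosted_by": "OPENAI"},
    {"name": "invalid-model", "hosted_by": "INVALID_PROVIDER"},  # not a valid enum member
]

# Pre-commit behavior: a single bad entry aborts the entire parse.
try:
    models = [DemoMLModel.model_validate(item) for item in raw]
except ValidationError as exc:
    print(f"Entire parse failed with {exc.error_count()} error(s)")
```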

ee/vellum_ee/workflows/display/tests/test_base_workflow_display.py

Lines changed: 101 additions & 0 deletions
@@ -843,3 +843,104 @@ def test_serialize_module__chat_message_prompt_block_validation_error():
     assert any(
         "validation error" in msg.lower() for msg in error_messages
     ), f"Expected validation error in error messages, got: {error_messages}"
+
+
+def test_parse_ml_models__skips_invalid_models_and_returns_valid_ones(caplog):
+    """
+    Tests that _parse_ml_models skips models with validation errors and returns valid ones.
+    """
+
+    # GIVEN a workflow class
+    class ExampleWorkflow(BaseWorkflow):
+        pass
+
+    # AND a mix of valid and invalid ml_models data
+    ml_models_raw = [
+        {"name": "gpt-4", "hosted_by": "OPENAI"},
+        {"name": "invalid-model", "hosted_by": "INVALID_PROVIDER"},
+        {"name": "claude-3", "hosted_by": "ANTHROPIC"},
+        {"name": "missing-hosted-by"},
+        {"name": "gemini", "hosted_by": "GOOGLE"},
+    ]
+
+    # WHEN creating a display with the ml_models
+    display = BaseWorkflowDisplay[ExampleWorkflow](ml_models=ml_models_raw)
+
+    # THEN only the valid models should be parsed
+    assert len(display._ml_models) == 3
+
+    # AND the valid models should have the correct names
+    model_names = [model.name for model in display._ml_models]
+    assert "gpt-4" in model_names
+    assert "claude-3" in model_names
+    assert "gemini" in model_names
+
+    # AND the invalid models should not be included
+    assert "invalid-model" not in model_names
+    assert "missing-hosted-by" not in model_names
+
+    # AND warnings should be logged for the invalid models
+    warning_messages = [record.message for record in caplog.records if record.levelname == "WARNING"]
+    assert len(warning_messages) == 2
+    assert any("invalid-model" in msg for msg in warning_messages)
+    assert any("missing-hosted-by" in msg for msg in warning_messages)
+
+
+def test_parse_ml_models__returns_empty_list_when_all_models_invalid(caplog):
+    """
+    Tests that _parse_ml_models returns an empty list when all models fail validation.
+    """
+
+    # GIVEN a workflow class
+    class ExampleWorkflow(BaseWorkflow):
+        pass
+
+    # AND ml_models data where all models are invalid
+    ml_models_raw = [
+        {"name": "invalid-model", "hosted_by": "INVALID_PROVIDER"},
+        {"name": "missing-hosted-by"},
+        {"hosted_by": "OPENAI"},
+    ]
+
+    # WHEN creating a display with the ml_models
+    display = BaseWorkflowDisplay[ExampleWorkflow](ml_models=ml_models_raw)
+
+    # THEN no models should be parsed
+    assert len(display._ml_models) == 0
+
+    # AND warnings should be logged for all invalid models
+    warning_messages = [record.message for record in caplog.records if record.levelname == "WARNING"]
+    assert len(warning_messages) == 3
+
+
+def test_parse_ml_models__returns_all_models_when_all_valid(caplog):
+    """
+    Tests that _parse_ml_models returns all models when all are valid.
+    """
+
+    # GIVEN a workflow class
+    class ExampleWorkflow(BaseWorkflow):
+        pass
+
+    # AND ml_models data where all models are valid
+    ml_models_raw = [
+        {"name": "gpt-4", "hosted_by": "OPENAI"},
+        {"name": "claude-3", "hosted_by": "ANTHROPIC"},
+        {"name": "gemini", "hosted_by": "GOOGLE"},
+    ]
+
+    # WHEN creating a display with the ml_models
+    display = BaseWorkflowDisplay[ExampleWorkflow](ml_models=ml_models_raw)
+
+    # THEN all models should be parsed
+    assert len(display._ml_models) == 3
+
+    # AND the models should have the correct names and hosted_by values
+    model_data = [(model.name, model.hosted_by.value) for model in display._ml_models]
+    assert ("gpt-4", "OPENAI") in model_data
+    assert ("claude-3", "ANTHROPIC") in model_data
+    assert ("gemini", "GOOGLE") in model_data
+
+    # AND no warnings should be logged
+    warning_messages = [record.message for record in caplog.records if record.levelname == "WARNING"]
+    assert len(warning_messages) == 0
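A side note on the fixture these tests lean on: `caplog` is pytest's built-in log-capture fixture, and `caplog.records` holds standard `logging.LogRecord` objects, which is why the assertions can filter on `levelname` and read `message`. A minimal self-contained illustration; `test_caplog_example` is a hypothetical test, not part of this diff:

```python
import logging

logger = logging.getLogger(__name__)


def test_caplog_example(caplog):
    # With default log levels, WARNING records propagate to the root logger,
    # where pytest's capture handler records (and formats) them.
    logger.warning("Skipping ML model 'bad-model' due to validation error: ValidationError")

    warning_messages = [record.message for record in caplog.records if record.levelname == "WARNING"]
    assert warning_messages == ["Skipping ML model 'bad-model' due to validation error: ValidationError"]
```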

ee/vellum_ee/workflows/display/workflows/base_workflow_display.py

Lines changed: 12 additions & 2 deletions
@@ -212,8 +212,18 @@ def __init__(
         self._ml_models = self._parse_ml_models(ml_models) if ml_models else []
 
     def _parse_ml_models(self, ml_models_raw: list) -> List[MLModel]:
-        """Parse raw list of dicts into MLModel instances using pydantic deserialization."""
-        return [MLModel.model_validate(item) for item in ml_models_raw]
+        """Parse raw list of dicts into MLModel instances using pydantic deserialization.
+
+        Models that fail validation are skipped and a warning is logged.
+        """
+        parsed_models: List[MLModel] = []
+        for item in ml_models_raw:
+            try:
+                parsed_models.append(MLModel.model_validate(item))
+            except Exception as e:
+                model_name = item.get("name") if isinstance(item, dict) else None
+                logger.warning(f"Skipping ML model '{model_name}' due to validation error: {type(e).__name__}")
+        return parsed_models
 
     def serialize(self) -> JsonObject:
         try:
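One design note on the hunk above: it catches a broad `Exception`, which also swallows non-validation bugs. A narrower variant would catch pydantic's `ValidationError`, which pydantic v2 raises both for bad field values and for inputs that are not mappings at all. A minimal sketch of that variant, with `ParsedModel` as a hypothetical stand-in for `MLModel`:

```python
import logging
from typing import List

from pydantic import BaseModel, ValidationError

logger = logging.getLogger(__name__)


class ParsedModel(BaseModel):
    """Hypothetical stand-in for the real MLModel."""

    name: str


def parse_resilient(raw_items: list) -> List[ParsedModel]:
    parsed: List[ParsedModel] = []
    for item in raw_items:
        try:
            parsed.append(ParsedModel.model_validate(item))
        except ValidationError as e:  # narrower than the diff's `except Exception`
            name = item.get("name") if isinstance(item, dict) else None
            logger.warning(f"Skipping ML model '{name}' due to validation error: {type(e).__name__}")
    return parsed
```

The broad catch in the committed version trades that precision for robustness: whatever goes wrong with a single entry, the rest of the list still parses.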
