@@ -843,3 +843,104 @@ def test_serialize_module__chat_message_prompt_block_validation_error():
843843 assert any (
844844 "validation error" in msg .lower () for msg in error_messages
845845 ), f"Expected validation error in error messages, got: { error_messages } "
846+
847+
848+ def test_parse_ml_models__skips_invalid_models_and_returns_valid_ones (caplog ):
849+ """
850+ Tests that _parse_ml_models skips models with validation errors and returns valid ones.
851+ """
852+
853+ # GIVEN a workflow class
854+ class ExampleWorkflow (BaseWorkflow ):
855+ pass
856+
857+ # AND a mix of valid and invalid ml_models data
858+ ml_models_raw = [
859+ {"name" : "gpt-4" , "hosted_by" : "OPENAI" },
860+ {"name" : "invalid-model" , "hosted_by" : "INVALID_PROVIDER" },
861+ {"name" : "claude-3" , "hosted_by" : "ANTHROPIC" },
862+ {"name" : "missing-hosted-by" },
863+ {"name" : "gemini" , "hosted_by" : "GOOGLE" },
864+ ]
865+
866+ # WHEN creating a display with the ml_models
867+ display = BaseWorkflowDisplay [ExampleWorkflow ](ml_models = ml_models_raw )
868+
869+ # THEN only the valid models should be parsed
870+ assert len (display ._ml_models ) == 3
871+
872+ # AND the valid models should have the correct names
873+ model_names = [model .name for model in display ._ml_models ]
874+ assert "gpt-4" in model_names
875+ assert "claude-3" in model_names
876+ assert "gemini" in model_names
877+
878+ # AND the invalid models should not be included
879+ assert "invalid-model" not in model_names
880+ assert "missing-hosted-by" not in model_names
881+
882+ # AND warnings should be logged for the invalid models
883+ warning_messages = [record .message for record in caplog .records if record .levelname == "WARNING" ]
884+ assert len (warning_messages ) == 2
885+ assert any ("invalid-model" in msg for msg in warning_messages )
886+ assert any ("missing-hosted-by" in msg for msg in warning_messages )
887+
888+
889+ def test_parse_ml_models__returns_empty_list_when_all_models_invalid (caplog ):
890+ """
891+ Tests that _parse_ml_models returns an empty list when all models fail validation.
892+ """
893+
894+ # GIVEN a workflow class
895+ class ExampleWorkflow (BaseWorkflow ):
896+ pass
897+
898+ # AND ml_models data where all models are invalid
899+ ml_models_raw = [
900+ {"name" : "invalid-model" , "hosted_by" : "INVALID_PROVIDER" },
901+ {"name" : "missing-hosted-by" },
902+ {"hosted_by" : "OPENAI" },
903+ ]
904+
905+ # WHEN creating a display with the ml_models
906+ display = BaseWorkflowDisplay [ExampleWorkflow ](ml_models = ml_models_raw )
907+
908+ # THEN no models should be parsed
909+ assert len (display ._ml_models ) == 0
910+
911+ # AND warnings should be logged for all invalid models
912+ warning_messages = [record .message for record in caplog .records if record .levelname == "WARNING" ]
913+ assert len (warning_messages ) == 3
914+
915+
916+ def test_parse_ml_models__returns_all_models_when_all_valid (caplog ):
917+ """
918+ Tests that _parse_ml_models returns all models when all are valid.
919+ """
920+
921+ # GIVEN a workflow class
922+ class ExampleWorkflow (BaseWorkflow ):
923+ pass
924+
925+ # AND ml_models data where all models are valid
926+ ml_models_raw = [
927+ {"name" : "gpt-4" , "hosted_by" : "OPENAI" },
928+ {"name" : "claude-3" , "hosted_by" : "ANTHROPIC" },
929+ {"name" : "gemini" , "hosted_by" : "GOOGLE" },
930+ ]
931+
932+ # WHEN creating a display with the ml_models
933+ display = BaseWorkflowDisplay [ExampleWorkflow ](ml_models = ml_models_raw )
934+
935+ # THEN all models should be parsed
936+ assert len (display ._ml_models ) == 3
937+
938+ # AND the models should have the correct names and hosted_by values
939+ model_data = [(model .name , model .hosted_by .value ) for model in display ._ml_models ]
940+ assert ("gpt-4" , "OPENAI" ) in model_data
941+ assert ("claude-3" , "ANTHROPIC" ) in model_data
942+ assert ("gemini" , "GOOGLE" ) in model_data
943+
944+ # AND no warnings should be logged
945+ warning_messages = [record .message for record in caplog .records if record .levelname == "WARNING" ]
946+ assert len (warning_messages ) == 0
0 commit comments