diff --git a/inference/core/workflows/execution_engine/introspection/schema_parser.py b/inference/core/workflows/execution_engine/introspection/schema_parser.py index 2dc7b40b8c..085f5a5ead 100644 --- a/inference/core/workflows/execution_engine/introspection/schema_parser.py +++ b/inference/core/workflows/execution_engine/introspection/schema_parser.py @@ -432,6 +432,7 @@ def retrieve_selectors_from_union_definition( inputs_accepting_batches_and_scalars=inputs_accepting_batches_and_scalars, inputs_enforcing_auto_batch_casting=inputs_enforcing_auto_batch_casting, is_list_element=is_list_element, + is_dict_element=is_dict_element, ) if result is None: continue @@ -464,12 +465,14 @@ def retrieve_selectors_from_union_definition( ) if not merged_references: return None + merged_is_list_element = is_list_element or any(r.is_list_element for r in results) + merged_is_dict_element = is_dict_element or any(r.is_dict_element for r in results) return SelectorDefinition( property_name=property_name, property_description=property_description, allowed_references=merged_references, - is_list_element=is_list_element, - is_dict_element=is_dict_element, + is_list_element=merged_is_list_element, + is_dict_element=merged_is_dict_element, dimensionality_offset=property_dimensionality_offset, is_dimensionality_reference_property=is_dimensionality_reference_property, ) diff --git a/inference/core/workflows/execution_engine/introspection/selectors_parser.py b/inference/core/workflows/execution_engine/introspection/selectors_parser.py index b8d4b788bf..af7ce80882 100644 --- a/inference/core/workflows/execution_engine/introspection/selectors_parser.py +++ b/inference/core/workflows/execution_engine/introspection/selectors_parser.py @@ -23,27 +23,30 @@ def get_step_selectors( step_manifest=step_manifest, property_name=property_name, ) - if selector_definition.is_list_element: + # Check runtime type to handle Union[List[...], Selector(...)] patterns + # where the actual value determines which path to take + if selector_definition.is_list_element and isinstance(property_value, list): selectors = retrieve_selectors_from_array( step_name=step_manifest.name, property_value=property_value, selector_definition=selector_definition, ) result.extend(selectors) - elif selector_definition.is_dict_element: + elif selector_definition.is_dict_element and isinstance(property_value, dict): selectors = retrieve_selectors_from_dictionary( step_name=step_manifest.name, property_value=property_value, selector_definition=selector_definition, ) result.extend(selectors) - else: + elif is_selector(property_value): selector = retrieve_selector_from_simple_property( step_name=step_manifest.name, property_value=property_value, selector_definition=selector_definition, ) result.append(selector) + # If none of the above, property_value is not a selector return [r for r in result if r is not None] diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_union_list_selector.py b/tests/workflows/integration_tests/execution/test_workflow_with_union_list_selector.py new file mode 100644 index 0000000000..420305961f --- /dev/null +++ b/tests/workflows/integration_tests/execution/test_workflow_with_union_list_selector.py @@ -0,0 +1,160 @@ +"""Integration tests for Union[List[...], Selector(...)] patterns.""" + +import numpy as np +import pytest + +from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS +from inference.core.managers.base import ModelManager +from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.execution_engine.core import ExecutionEngine + + +WORKFLOW_WITH_SELECTOR_TO_LIST = { + "version": "1.0", + "inputs": [ + {"type": "WorkflowImage", "name": "image"}, + { + "type": "WorkflowParameter", + "name": "classes_to_consider", + }, + ], + "steps": [ + { + "type": "RoboflowObjectDetectionModel", + "name": "detection", + "image": "$inputs.image", + "model_id": "yolov8n-640", + }, + { + "type": "DetectionsConsensus", + "name": "consensus", + "predictions_batches": ["$steps.detection.predictions"], + "required_votes": 1, + "classes_to_consider": "$inputs.classes_to_consider", + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "predictions", + "selector": "$steps.consensus.predictions", + } + ], +} + + +WORKFLOW_WITH_LITERAL_LIST = { + "version": "1.0", + "inputs": [ + {"type": "WorkflowImage", "name": "image"}, + ], + "steps": [ + { + "type": "RoboflowObjectDetectionModel", + "name": "detection", + "image": "$inputs.image", + "model_id": "yolov8n-640", + }, + { + "type": "DetectionsConsensus", + "name": "consensus", + "predictions_batches": ["$steps.detection.predictions"], + "required_votes": 1, + "classes_to_consider": ["person"], + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "predictions", + "selector": "$steps.consensus.predictions", + } + ], +} + + +def test_union_list_selector_with_selector_to_list( + model_manager: ModelManager, + crowd_image: np.ndarray, +) -> None: + """Test Union[List[str], Selector(...)] when using a selector to a list.""" + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=WORKFLOW_WITH_SELECTOR_TO_LIST, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + result = execution_engine.run( + runtime_parameters={ + "image": crowd_image, + "classes_to_consider": ["person"], + } + ) + + # then + assert isinstance(result, list), "Expected list of results" + assert len(result) == 1, "Expected single result" + assert "predictions" in result[0], "Expected predictions in output" + # Verify that the selector was properly resolved and the workflow executed + + +def test_union_list_selector_with_literal_list( + model_manager: ModelManager, + crowd_image: np.ndarray, +) -> None: + """Test Union[List[str], Selector(...)] when using a literal list.""" + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=WORKFLOW_WITH_LITERAL_LIST, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + result = execution_engine.run( + runtime_parameters={ + "image": crowd_image, + } + ) + + # then + assert isinstance(result, list), "Expected list of results" + assert len(result) == 1, "Expected single result" + assert "predictions" in result[0], "Expected predictions in output" + # Verify that the literal list was properly handled + + +def test_union_list_selector_validates_type_mismatch( + model_manager: ModelManager, + crowd_image: np.ndarray, +) -> None: + """Test that type validation catches invalid selector resolution.""" + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=WORKFLOW_WITH_SELECTOR_TO_LIST, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when/then - passing a string instead of a list should fail validation + with pytest.raises(Exception): # Should raise validation error + execution_engine.run( + runtime_parameters={ + "image": crowd_image, + "classes_to_consider": "person", # String instead of list + } + ) diff --git a/tests/workflows/unit_tests/execution_engine/introspection/test_schema_parser.py b/tests/workflows/unit_tests/execution_engine/introspection/test_schema_parser.py index 6d91895818..4ae47a8546 100644 --- a/tests/workflows/unit_tests/execution_engine/introspection/test_schema_parser.py +++ b/tests/workflows/unit_tests/execution_engine/introspection/test_schema_parser.py @@ -6,8 +6,10 @@ from inference.core.workflows.execution_engine.entities.types import ( BOOLEAN_KIND, IMAGE_KIND, + LIST_OF_VALUES_KIND, OBJECT_DETECTION_PREDICTION_KIND, STRING_KIND, + Selector, StepOutputImageSelector, StepOutputSelector, StepSelector, @@ -531,3 +533,83 @@ def describe_outputs(cls) -> List[OutputDefinition]: ) }, ) + + +def test_parse_block_manifest_when_manifest_defines_union_of_list_str_or_selector() -> ( + None +): + # Union[List[str], Selector] should have is_list_element=False + # because List[str] doesn't allow selectors inside + class Manifest(WorkflowBlockManifest): + type: Literal["MyManifest"] + name: str = Field(description="name field") + tags: Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])] = Field( + description="Tags can be a literal list or a selector to a list" + ) + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [] + + manifest_metadata = parse_block_manifest(manifest_type=Manifest) + + assert manifest_metadata == BlockManifestMetadata( + primitive_types={ + "name": PrimitiveTypeDefinition( + property_name="name", + property_description="name field", + type_annotation="str", + ), + "tags": PrimitiveTypeDefinition( + property_name="tags", + property_description="Tags can be a literal list or a selector to a list", + type_annotation="List[str]", + ), + }, + selectors={ + "tags": SelectorDefinition( + property_name="tags", + property_description="Tags can be a literal list or a selector to a list", + allowed_references=[ + ReferenceDefinition( + selected_element="any_data", + kind=[LIST_OF_VALUES_KIND], + points_to_batch={False}, + ), + ], + is_list_element=False, + is_dict_element=False, + dimensionality_offset=0, + is_dimensionality_reference_property=False, + ) + }, + ) + + +def test_parse_block_manifest_when_manifest_defines_union_of_list_with_selectors_or_selector() -> ( + None +): + # Union[List[Union[Selector, str]], Selector] should have is_list_element=True + # because List items can contain selectors (like DatasetUpload.registration_tags) + class Manifest(WorkflowBlockManifest): + type: Literal["MyManifest"] + name: str = Field(description="name field") + registration_tags: Union[ + List[Union[Selector(kind=[STRING_KIND]), str]], + Selector(kind=[LIST_OF_VALUES_KIND]), + ] = Field(description="Tags with selectors inside the list") + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [] + + manifest_metadata = parse_block_manifest(manifest_type=Manifest) + + assert "registration_tags" in manifest_metadata.selectors + selector_def = manifest_metadata.selectors["registration_tags"] + assert selector_def.is_list_element is True + assert selector_def.is_dict_element is False + + all_kinds = {k.name for ref in selector_def.allowed_references for k in ref.kind} + assert STRING_KIND.name in all_kinds + assert LIST_OF_VALUES_KIND.name in all_kinds diff --git a/tests/workflows/unit_tests/execution_engine/introspection/test_selectors_parser.py b/tests/workflows/unit_tests/execution_engine/introspection/test_selectors_parser.py index fe020eb78c..71f06232fc 100644 --- a/tests/workflows/unit_tests/execution_engine/introspection/test_selectors_parser.py +++ b/tests/workflows/unit_tests/execution_engine/introspection/test_selectors_parser.py @@ -6,7 +6,9 @@ from inference.core.workflows.execution_engine.entities.types import ( BOOLEAN_KIND, IMAGE_KIND, + LIST_OF_VALUES_KIND, STRING_KIND, + Selector, StepOutputSelector, WorkflowImageSelector, WorkflowParameterSelector, @@ -160,3 +162,132 @@ def describe_outputs(cls) -> List[OutputDefinition]: assert ( selectors[0].definition.property_name == "param" ), "Selector definition must hold in terms of property name" + + +def test_get_step_selectors_when_union_of_list_str_or_selector_receives_selector() -> ( + None +): + # Union[List[str], Selector] with a selector value + class Manifest(WorkflowBlockManifest): + type: Literal["UnionListTest"] + name: str = Field(description="name field") + tags: Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])] = Field( + description="Tags can be a literal list or a selector to a list" + ) + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [] + + step_manifest = Manifest(type="UnionListTest", name="my_step", tags="$inputs.tags") + selectors = get_step_selectors(step_manifest=step_manifest) + + assert len(selectors) == 1 + assert selectors[0].value == "$inputs.tags" + assert selectors[0].definition.is_list_element is False + assert selectors[0].index is None + + +def test_get_step_selectors_when_union_of_list_str_or_selector_receives_literal_list() -> ( + None +): + # Union[List[str], Selector] with a literal list - no selectors + class Manifest(WorkflowBlockManifest): + type: Literal["UnionListTest"] + name: str = Field(description="name field") + tags: Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])] = Field( + description="Tags can be a literal list or a selector to a list" + ) + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [] + + step_manifest = Manifest( + type="UnionListTest", name="my_step", tags=["tag1", "tag2", "tag3"] + ) + selectors = get_step_selectors(step_manifest=step_manifest) + + assert len(selectors) == 0 + + +def test_get_step_selectors_when_union_of_list_str_or_selector_receives_list_with_selector_like_strings() -> ( + None +): + # List[str] should NOT parse selector-like strings - they're just strings + class Manifest(WorkflowBlockManifest): + type: Literal["UnionListTest"] + name: str = Field(description="name field") + tags: Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])] = Field( + description="Tags can be a literal list or a selector to a list" + ) + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [] + + step_manifest = Manifest( + type="UnionListTest", + name="my_step", + tags=["literal_tag", "$inputs.tag", "$inputs.another_tag"], + ) + selectors = get_step_selectors(step_manifest=step_manifest) + + # these look like selectors but List[str] means they're literal strings + assert len(selectors) == 0 + + +def test_get_step_selectors_when_union_of_list_with_selectors_or_selector_receives_mixed_list() -> ( + None +): + # List[Union[Selector, str]] SHOULD parse selectors inside the list + class Manifest(WorkflowBlockManifest): + type: Literal["UnionListTest"] + name: str = Field(description="name field") + registration_tags: Union[ + List[Union[Selector(kind=[STRING_KIND]), str]], + Selector(kind=[LIST_OF_VALUES_KIND]), + ] = Field(description="Tags with selectors inside the list") + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [] + + step_manifest = Manifest( + type="UnionListTest", + name="my_step", + registration_tags=["literal_tag", "$inputs.tag", "$inputs.another_tag"], + ) + selectors = get_step_selectors(step_manifest=step_manifest) + + assert len(selectors) == 2 + assert selectors[0].value == "$inputs.tag" + assert selectors[0].index == 1 + assert selectors[1].value == "$inputs.another_tag" + assert selectors[1].index == 2 + + +def test_get_step_selectors_when_union_of_list_with_selectors_or_selector_receives_direct_selector() -> ( + None +): + # Union[List[Union[Selector, str]], Selector] with a direct selector + class Manifest(WorkflowBlockManifest): + type: Literal["UnionListTest"] + name: str = Field(description="name field") + registration_tags: Union[ + List[Union[Selector(kind=[STRING_KIND]), str]], + Selector(kind=[LIST_OF_VALUES_KIND]), + ] = Field(description="Tags with selectors inside the list") + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [] + + step_manifest = Manifest( + type="UnionListTest", name="my_step", registration_tags="$inputs.tags" + ) + selectors = get_step_selectors(step_manifest=step_manifest) + + assert len(selectors) == 1 + assert selectors[0].value == "$inputs.tags" + assert selectors[0].index is None