Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def retrieve_selectors_from_union_definition(
inputs_accepting_batches_and_scalars=inputs_accepting_batches_and_scalars,
inputs_enforcing_auto_batch_casting=inputs_enforcing_auto_batch_casting,
is_list_element=is_list_element,
is_dict_element=is_dict_element,
)
if result is None:
continue
Expand Down Expand Up @@ -464,12 +465,14 @@ def retrieve_selectors_from_union_definition(
)
if not merged_references:
return None
merged_is_list_element = is_list_element or any(r.is_list_element for r in results)
merged_is_dict_element = is_dict_element or any(r.is_dict_element for r in results)
return SelectorDefinition(
property_name=property_name,
property_description=property_description,
allowed_references=merged_references,
is_list_element=is_list_element,
is_dict_element=is_dict_element,
is_list_element=merged_is_list_element,
is_dict_element=merged_is_dict_element,
dimensionality_offset=property_dimensionality_offset,
is_dimensionality_reference_property=is_dimensionality_reference_property,
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,30 @@ def get_step_selectors(
step_manifest=step_manifest,
property_name=property_name,
)
if selector_definition.is_list_element:
# Check runtime type to handle Union[List[...], Selector(...)] patterns
# where the actual value determines which path to take
if selector_definition.is_list_element and isinstance(property_value, list):
selectors = retrieve_selectors_from_array(
step_name=step_manifest.name,
property_value=property_value,
selector_definition=selector_definition,
)
result.extend(selectors)
elif selector_definition.is_dict_element:
elif selector_definition.is_dict_element and isinstance(property_value, dict):
selectors = retrieve_selectors_from_dictionary(
step_name=step_manifest.name,
property_value=property_value,
selector_definition=selector_definition,
)
result.extend(selectors)
else:
elif is_selector(property_value):
selector = retrieve_selector_from_simple_property(
step_name=step_manifest.name,
property_value=property_value,
selector_definition=selector_definition,
)
result.append(selector)
# If none of the above, property_value is not a selector
return [r for r in result if r is not None]


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Integration tests for Union[List[...], Selector(...)] patterns."""

import numpy as np
import pytest

from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS
from inference.core.managers.base import ModelManager
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
from inference.core.workflows.execution_engine.core import ExecutionEngine


WORKFLOW_WITH_SELECTOR_TO_LIST = {
"version": "1.0",
"inputs": [
{"type": "WorkflowImage", "name": "image"},
{
"type": "WorkflowParameter",
"name": "classes_to_consider",
},
],
"steps": [
{
"type": "RoboflowObjectDetectionModel",
"name": "detection",
"image": "$inputs.image",
"model_id": "yolov8n-640",
},
{
"type": "DetectionsConsensus",
"name": "consensus",
"predictions_batches": ["$steps.detection.predictions"],
"required_votes": 1,
"classes_to_consider": "$inputs.classes_to_consider",
},
],
"outputs": [
{
"type": "JsonField",
"name": "predictions",
"selector": "$steps.consensus.predictions",
}
],
}


WORKFLOW_WITH_LITERAL_LIST = {
"version": "1.0",
"inputs": [
{"type": "WorkflowImage", "name": "image"},
],
"steps": [
{
"type": "RoboflowObjectDetectionModel",
"name": "detection",
"image": "$inputs.image",
"model_id": "yolov8n-640",
},
{
"type": "DetectionsConsensus",
"name": "consensus",
"predictions_batches": ["$steps.detection.predictions"],
"required_votes": 1,
"classes_to_consider": ["person"],
},
],
"outputs": [
{
"type": "JsonField",
"name": "predictions",
"selector": "$steps.consensus.predictions",
}
],
}


def test_union_list_selector_with_selector_to_list(
model_manager: ModelManager,
crowd_image: np.ndarray,
) -> None:
"""Test Union[List[str], Selector(...)] when using a selector to a list."""
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=WORKFLOW_WITH_SELECTOR_TO_LIST,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
result = execution_engine.run(
runtime_parameters={
"image": crowd_image,
"classes_to_consider": ["person"],
}
)

# then
assert isinstance(result, list), "Expected list of results"
assert len(result) == 1, "Expected single result"
assert "predictions" in result[0], "Expected predictions in output"
# Verify that the selector was properly resolved and the workflow executed


def test_union_list_selector_with_literal_list(
model_manager: ModelManager,
crowd_image: np.ndarray,
) -> None:
"""Test Union[List[str], Selector(...)] when using a literal list."""
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=WORKFLOW_WITH_LITERAL_LIST,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
result = execution_engine.run(
runtime_parameters={
"image": crowd_image,
}
)

# then
assert isinstance(result, list), "Expected list of results"
assert len(result) == 1, "Expected single result"
assert "predictions" in result[0], "Expected predictions in output"
# Verify that the literal list was properly handled


def test_union_list_selector_validates_type_mismatch(
model_manager: ModelManager,
crowd_image: np.ndarray,
) -> None:
"""Test that type validation catches invalid selector resolution."""
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=WORKFLOW_WITH_SELECTOR_TO_LIST,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when/then - passing a string instead of a list should fail validation
with pytest.raises(Exception): # Should raise validation error
execution_engine.run(
runtime_parameters={
"image": crowd_image,
"classes_to_consider": "person", # String instead of list
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from inference.core.workflows.execution_engine.entities.types import (
BOOLEAN_KIND,
IMAGE_KIND,
LIST_OF_VALUES_KIND,
OBJECT_DETECTION_PREDICTION_KIND,
STRING_KIND,
Selector,
StepOutputImageSelector,
StepOutputSelector,
StepSelector,
Expand Down Expand Up @@ -531,3 +533,83 @@ def describe_outputs(cls) -> List[OutputDefinition]:
)
},
)


def test_parse_block_manifest_when_manifest_defines_union_of_list_str_or_selector() -> (
None
):
# Union[List[str], Selector] should have is_list_element=False
# because List[str] doesn't allow selectors inside
class Manifest(WorkflowBlockManifest):
type: Literal["MyManifest"]
name: str = Field(description="name field")
tags: Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])] = Field(
description="Tags can be a literal list or a selector to a list"
)

@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return []

manifest_metadata = parse_block_manifest(manifest_type=Manifest)

assert manifest_metadata == BlockManifestMetadata(
primitive_types={
"name": PrimitiveTypeDefinition(
property_name="name",
property_description="name field",
type_annotation="str",
),
"tags": PrimitiveTypeDefinition(
property_name="tags",
property_description="Tags can be a literal list or a selector to a list",
type_annotation="List[str]",
),
},
selectors={
"tags": SelectorDefinition(
property_name="tags",
property_description="Tags can be a literal list or a selector to a list",
allowed_references=[
ReferenceDefinition(
selected_element="any_data",
kind=[LIST_OF_VALUES_KIND],
points_to_batch={False},
),
],
is_list_element=False,
is_dict_element=False,
dimensionality_offset=0,
is_dimensionality_reference_property=False,
)
},
)


def test_parse_block_manifest_when_manifest_defines_union_of_list_with_selectors_or_selector() -> (
None
):
# Union[List[Union[Selector, str]], Selector] should have is_list_element=True
# because List items can contain selectors (like DatasetUpload.registration_tags)
class Manifest(WorkflowBlockManifest):
type: Literal["MyManifest"]
name: str = Field(description="name field")
registration_tags: Union[
List[Union[Selector(kind=[STRING_KIND]), str]],
Selector(kind=[LIST_OF_VALUES_KIND]),
] = Field(description="Tags with selectors inside the list")

@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return []

manifest_metadata = parse_block_manifest(manifest_type=Manifest)

assert "registration_tags" in manifest_metadata.selectors
selector_def = manifest_metadata.selectors["registration_tags"]
assert selector_def.is_list_element is True
assert selector_def.is_dict_element is False

all_kinds = {k.name for ref in selector_def.allowed_references for k in ref.kind}
assert STRING_KIND.name in all_kinds
assert LIST_OF_VALUES_KIND.name in all_kinds
Loading