Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,19 @@ def retrieve_selectors_from_union_definition(
+ union_definition.get(ONE_OF_KEY, [])
+ union_definition.get(ALL_OF_KEY, [])
)
# Check if any union variant is an array or dict type
# This handles Union[List[...], Selector(...)] patterns
contains_array_type = False
contains_dict_type = False
for type_definition in union_types:
if type_definition.get("type") == "array" and ITEMS_KEY in type_definition:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use constants for key names

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and also values

contains_array_type = True
if (
type_definition.get("type") == "object"
and "additionalProperties" in type_definition
):
contains_dict_type = True

results = []
for type_definition in union_types:
result = retrieve_selectors_from_simple_property(
Expand Down Expand Up @@ -468,8 +481,8 @@ def retrieve_selectors_from_union_definition(
property_name=property_name,
property_description=property_description,
allowed_references=merged_references,
is_list_element=is_list_element,
is_dict_element=is_dict_element,
is_list_element=is_list_element or contains_array_type,
is_dict_element=is_dict_element or contains_dict_type,
dimensionality_offset=property_dimensionality_offset,
is_dimensionality_reference_property=is_dimensionality_reference_property,
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,30 @@ def get_step_selectors(
step_manifest=step_manifest,
property_name=property_name,
)
if selector_definition.is_list_element:
# Check runtime type to handle Union[List[...], Selector(...)] patterns
# where the actual value determines which path to take
if selector_definition.is_list_element and isinstance(property_value, list):
selectors = retrieve_selectors_from_array(
step_name=step_manifest.name,
property_value=property_value,
selector_definition=selector_definition,
)
result.extend(selectors)
elif selector_definition.is_dict_element:
elif selector_definition.is_dict_element and isinstance(property_value, dict):
selectors = retrieve_selectors_from_dictionary(
step_name=step_manifest.name,
property_value=property_value,
selector_definition=selector_definition,
)
result.extend(selectors)
else:
elif is_selector(property_value):
selector = retrieve_selector_from_simple_property(
step_name=step_manifest.name,
property_value=property_value,
selector_definition=selector_definition,
)
result.append(selector)
# If none of the above, property_value is not a selector
return [r for r in result if r is not None]


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Integration tests for Union[List[...], Selector(...)] patterns in workflow blocks.

This test verifies that blocks can properly handle properties defined as
Union[List[T], Selector(...)] where the value can be either:
- A literal list: ["tag1", "tag2"]
- A selector to a list: $inputs.tags
- A mixed list with literals and selectors: ["literal", "$inputs.tag"]
"""

import numpy as np
import pytest

from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS
from inference.core.managers.base import ModelManager
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
from inference.core.workflows.execution_engine.core import ExecutionEngine


WORKFLOW_WITH_SELECTOR_TO_LIST = {
"version": "1.0",
"inputs": [
{"type": "WorkflowImage", "name": "image"},
{
"type": "WorkflowParameter",
"name": "classes_to_consider",
},
],
"steps": [
{
"type": "RoboflowObjectDetectionModel",
"name": "detection",
"image": "$inputs.image",
"model_id": "yolov8n-640",
},
{
"type": "DetectionsConsensus",
"name": "consensus",
"predictions_batches": ["$steps.detection.predictions"],
"required_votes": 1,
"classes_to_consider": "$inputs.classes_to_consider",
},
],
"outputs": [
{
"type": "JsonField",
"name": "predictions",
"selector": "$steps.consensus.predictions",
}
],
}


WORKFLOW_WITH_LITERAL_LIST = {
"version": "1.0",
"inputs": [
{"type": "WorkflowImage", "name": "image"},
],
"steps": [
{
"type": "RoboflowObjectDetectionModel",
"name": "detection",
"image": "$inputs.image",
"model_id": "yolov8n-640",
},
{
"type": "DetectionsConsensus",
"name": "consensus",
"predictions_batches": ["$steps.detection.predictions"],
"required_votes": 1,
"classes_to_consider": ["person"],
},
],
"outputs": [
{
"type": "JsonField",
"name": "predictions",
"selector": "$steps.consensus.predictions",
}
],
}


def test_union_list_selector_with_selector_to_list(
model_manager: ModelManager,
crowd_image: np.ndarray,
) -> None:
"""Test Union[List[str], Selector(...)] when using a selector to a list."""
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=WORKFLOW_WITH_SELECTOR_TO_LIST,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
result = execution_engine.run(
runtime_parameters={
"image": crowd_image,
"classes_to_consider": ["person"],
}
)

# then
assert isinstance(result, list), "Expected list of results"
assert len(result) == 1, "Expected single result"
assert "predictions" in result[0], "Expected predictions in output"
# Verify that the selector was properly resolved and the workflow executed


def test_union_list_selector_with_literal_list(
model_manager: ModelManager,
crowd_image: np.ndarray,
) -> None:
"""Test Union[List[str], Selector(...)] when using a literal list."""
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=WORKFLOW_WITH_LITERAL_LIST,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
result = execution_engine.run(
runtime_parameters={
"image": crowd_image,
}
)

# then
assert isinstance(result, list), "Expected list of results"
assert len(result) == 1, "Expected single result"
assert "predictions" in result[0], "Expected predictions in output"
# Verify that the literal list was properly handled


def test_union_list_selector_validates_type_mismatch(
model_manager: ModelManager,
crowd_image: np.ndarray,
) -> None:
"""Test that type validation catches invalid selector resolution."""
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=WORKFLOW_WITH_SELECTOR_TO_LIST,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when/then - passing a string instead of a list should fail validation
with pytest.raises(Exception): # Should raise validation error
execution_engine.run(
runtime_parameters={
"image": crowd_image,
"classes_to_consider": "person", # String instead of list
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from inference.core.workflows.execution_engine.entities.types import (
BOOLEAN_KIND,
IMAGE_KIND,
LIST_OF_VALUES_KIND,
OBJECT_DETECTION_PREDICTION_KIND,
STRING_KIND,
Selector,
StepOutputImageSelector,
StepOutputSelector,
StepSelector,
Expand Down Expand Up @@ -531,3 +533,61 @@ def describe_outputs(cls) -> List[OutputDefinition]:
)
},
)


def test_parse_block_manifest_when_manifest_defines_union_of_list_or_selector() -> (
None
):
"""Test that Union[List[...], Selector(...)] properly sets is_list_element=True.

This is a regression test for the bug where the schema parser would not detect
that a property could receive a list when defined as Union[List[T], Selector(...)].
"""
# given

class Manifest(WorkflowBlockManifest):
type: Literal["MyManifest"]
name: str = Field(description="name field")
tags: Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])] = Field(
description="Tags can be a literal list or a selector to a list"
)

@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return []

# when
manifest_metadata = parse_block_manifest(manifest_type=Manifest)

# then
assert manifest_metadata == BlockManifestMetadata(
primitive_types={
"name": PrimitiveTypeDefinition(
property_name="name",
property_description="name field",
type_annotation="str",
),
"tags": PrimitiveTypeDefinition(
property_name="tags",
property_description="Tags can be a literal list or a selector to a list",
type_annotation="List[str]",
),
},
selectors={
"tags": SelectorDefinition(
property_name="tags",
property_description="Tags can be a literal list or a selector to a list",
allowed_references=[
ReferenceDefinition(
selected_element="any_data",
kind=[LIST_OF_VALUES_KIND],
points_to_batch={False},
),
],
is_list_element=True,
is_dict_element=False,
dimensionality_offset=0,
is_dimensionality_reference_property=False,
)
},
)
Loading