Add fixes for step inputs construction #2108

Closed
PawelPeczek-Roboflow wants to merge 1 commit into main from test/stacked-condition-flow

Conversation

@PawelPeczek-Roboflow
Collaborator

What does this PR do?

Related Issue(s):

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update
  • Refactoring (no functional changes)
  • Other:

Testing

  • I have tested this change locally
  • I have added/updated tests for this change

Test details:

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code where necessary, particularly in hard-to-understand areas
  • My changes generate no new warnings or errors
  • I have updated the documentation accordingly (if applicable)

Additional Context

@codeflash-ai
Contributor

codeflash-ai bot commented Mar 12, 2026

⚡️ Codeflash found optimizations for this PR

📄 19% (0.19x) speedup for Batch.remove_by_indices in inference/core/workflows/execution_engine/entities/base.py

⏱️ Runtime: 1.54 milliseconds → 1.29 milliseconds (best of 242 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch test/stacked-condition-flow).


Comment on lines +341 to +358
+    for dim in reversed(sorted_dims):
+        for idx in by_dim[dim]:
+            if dim == sorted_dims[-1] or idx in has_child:
+                parent = idx[:-1]
+                if parent:
+                    has_child.add(parent)
+
-    # Early exit if intersection becomes empty
-    if not intersection:
-        return set()
+    # Top-down: keep indices only if full prefix chain exists
+    valid: Dict[int, Set[DynamicBatchIndex]] = {dim: set() for dim in sorted_dims}
+    for dim in sorted_dims:
+        for idx in by_dim[dim]:
+            parent = idx[:-1]
+            if dim == sorted_dims[0]:
+                if idx in has_child:
+                    valid[dim].add(idx)
+            elif parent in valid[prev_dim[dim]]:
+                if dim == sorted_dims[-1] or idx in has_child:
+                    valid[dim].add(idx)
Contributor


⚡️Codeflash found 17% (0.17x) speedup for get_masks_intersection_for_dimensions in inference/core/workflows/execution_engine/v1/executor/execution_data_manager/step_input_assembler.py

⏱️ Runtime: 4.44 milliseconds → 3.79 milliseconds (best of 31 runs)

📝 Explanation and details

The optimized code hoists invariant dimension checks (dim == last_dim, dim == first_dim) outside the per-index inner loops, eliminating redundant comparisons on every iteration. In the bottom-up pass, separating the last_dim branch avoids checking dim == sorted_dims[-1] or idx in has_child for 2227 leaf indices (profiler shows ~0.9 ms saved on that condition alone). Similarly, in the top-down pass, splitting first_dim and last_dim cases removes ~6000 redundant dim == sorted_dims[0] checks and dictionary lookups of prev_dim[dim] that would otherwise execute inside the loop. Hoisting prev_set = valid[prev_dim[dim]] outside the inner loop further cuts repeated dictionary accesses. These changes reduce CPU cycles in hot paths without altering logic, yielding a 17% runtime improvement (4.44 ms → 3.79 ms) with no correctness regressions.
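
The hoisting described above can be illustrated on the bottom-up pass alone. The following is a minimal sketch, not the project's code: `Index` is a stand-in for `DynamicBatchIndex` (assumed here to behave like a plain tuple of ints), and the function names `mark_parents_naive` and `mark_parents_hoisted` are hypothetical. It only demonstrates that moving the loop-invariant `dim == last_dim` check out of the inner loop leaves the result unchanged.

```python
from typing import Dict, List, Set, Tuple

# Stand-in for DynamicBatchIndex (assumption: a plain tuple of ints).
Index = Tuple[int, ...]


def mark_parents_naive(by_dim: Dict[int, Set[Index]], sorted_dims: List[int]) -> Set[Index]:
    """Bottom-up pass with the invariant check repeated for every index."""
    has_child: Set[Index] = set()
    for dim in reversed(sorted_dims):
        for idx in by_dim[dim]:
            # `dim == sorted_dims[-1]` does not depend on idx, yet runs per index.
            if dim == sorted_dims[-1] or idx in has_child:
                parent = idx[:-1]
                if parent:
                    has_child.add(parent)
    return has_child


def mark_parents_hoisted(by_dim: Dict[int, Set[Index]], sorted_dims: List[int]) -> Set[Index]:
    """Same pass with the invariant check evaluated once per dimension."""
    has_child: Set[Index] = set()
    last_dim = sorted_dims[-1]
    for dim in reversed(sorted_dims):
        if dim == last_dim:
            # Leaf level: every index marks its parent unconditionally.
            for idx in by_dim[dim]:
                parent = idx[:-1]
                if parent:
                    has_child.add(parent)
        else:
            # Inner levels: only indices that themselves have a child propagate.
            for idx in by_dim[dim]:
                if idx in has_child:
                    parent = idx[:-1]
                    if parent:
                        has_child.add(parent)
    return has_child


by_dim = {1: {(0,), (5,)}, 2: {(0, 1), (7, 8)}, 3: {(0, 1, 2), (9, 9, 9)}}
sorted_dims = [1, 2, 3]
naive = mark_parents_naive(by_dim, sorted_dims)
hoisted = mark_parents_hoisted(by_dim, sorted_dims)
assert naive == hoisted
```

Both variants do the same set work; the hoisted form only removes per-index comparisons, which is why the gain shows up mainly on inputs with many indices per dimension.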

Correctness verification report:

| Test | Status |
|---|---|
| ⏪ Replay Tests | 🔘 None Found |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 40 Passed |
| 📊 Tests Coverage | 100.0% |

🌀 Generated Regression Tests
from typing import Dict, List, Optional, Set

# imports
import pytest  # used for our unit tests
from inference.core.workflows.execution_engine.v1.executor.execution_data_manager.dynamic_batches_manager import (
    DynamicBatchIndex,
)

# import the function and the real DynamicBatchIndex class used by the function
from inference.core.workflows.execution_engine.v1.executor.execution_data_manager.step_input_assembler import (
    get_masks_intersection_for_dimensions,
)


# helper to convert the returned sets of DynamicBatchIndex into sets of plain tuples
def _to_tuple_map(
    result: Dict[int, Optional[Set[DynamicBatchIndex]]]
) -> Dict[int, Optional[Set[tuple]]]:
    """
    Convert a mapping of int -> Optional[Set[DynamicBatchIndex]] into
    int -> Optional[Set[tuple]] so comparisons are simpler and not dependent
    on the concrete DynamicBatchIndex equality implementation.
    """
    out: Dict[int, Optional[Set[tuple]]] = {}
    for k, v in result.items():
        if v is None:
            out[k] = None
        else:
            out[k] = {tuple(idx) for idx in v}
    return out


def test_empty_batch_masks_returns_none_for_each_dimension():
    # When batch_masks is empty, the function should return None for each requested dimension
    dims = {1, 2, 3}
    result = get_masks_intersection_for_dimensions(
        [], dims
    )  # 1.19μs -> 1.25μs (4.87% slower)
    # Expect exactly {1: None, 2: None, 3: None}
    assert result == {1: None, 2: None, 3: None}


def test_single_dimension_collects_all_indices_of_that_length():
    # For a single dimension, the function returns a set of all indices that have that length
    dims = {1}
    # Create a few DynamicBatchIndex instances of length 1
    i0 = DynamicBatchIndex((0,))
    i1 = DynamicBatchIndex((1,))
    i2 = DynamicBatchIndex((2,))
    # Provide masks as a list of sets (could be multiple masks)
    masks: List[Set[DynamicBatchIndex]] = [{i0, i1}, {i2}]
    result = get_masks_intersection_for_dimensions(
        masks, dims
    )  # 3.46μs -> 3.56μs (2.81% slower)
    # Convert to tuple form for easy comparison
    tuple_map = _to_tuple_map(result)
    assert tuple_map == {
        1: {(0,), (1,), (2,)}
    }  # all three single-length indices are present


def test_hierarchical_chain_includes_full_chains_and_excludes_orphans():
    # This test verifies the "top-down" and "bottom-up" logic:
    # - A full chain (len1 -> len2 -> len3) should be present in all levels
    # - Orphan bottom entries (without parents) should be excluded
    dims = {1, 2, 3}

    # Construct a valid 3-level chain: (0,) -> (0, 1) -> (0, 1, 2)
    a1 = DynamicBatchIndex((0,))
    a2 = DynamicBatchIndex((0, 1))
    a3 = DynamicBatchIndex((0, 1, 2))

    # An orphan bottom index with no corresponding parents
    orphan_bottom = DynamicBatchIndex((9, 9, 9))

    # A middle-level index that lacks the top-level parent (should be excluded)
    orphan_middle = DynamicBatchIndex((7, 8))

    # A top-level index with no children (should be excluded because it isn't in has_child)
    top_no_child = DynamicBatchIndex((5,))

    # Distribute indices across two masks to ensure aggregation works across masks
    mask1 = {a1, a3, orphan_bottom}
    mask2 = {a2, orphan_middle, top_no_child}
    masks = [mask1, mask2]

    result = get_masks_intersection_for_dimensions(
        masks, dims
    )  # 11.9μs -> 11.6μs (2.68% faster)
    mapped = _to_tuple_map(result)

    # Expected:
    # - dim 1: only (0,) should be kept (it has a descendant chain). (5,) has no child and is excluded.
    # - dim 2: only (0,1) should be kept because its parent (0,) exists in valid[1] and it has a descendant
    # - dim 3: only (0,1,2) should be kept. orphan_bottom has no parent chain and should be excluded.
    assert mapped == {
        1: {(0,)},
        2: {(0, 1)},
        3: {(0, 1, 2)},
    }


def test_non_consecutive_dimensions_produces_no_valid_chains():
    # If dimensions skip lengths (e.g., {1, 3}), chain logic should fail because parent lengths won't line up
    dims = {1, 3}
    # Provide a length-3 index and a length-1 index
    idx3 = DynamicBatchIndex((1, 2, 3))
    idx1 = DynamicBatchIndex((1,))
    masks = [{idx3, idx1}]
    result = get_masks_intersection_for_dimensions(
        masks, dims
    )  # 7.40μs -> 7.43μs (0.404% slower)
    mapped = _to_tuple_map(result)
    # No valid chains can be formed because the parent for the length-3 index is length-2 (which is not requested)
    # The algorithm therefore should return empty sets for both requested dimensions (not None)
    assert mapped == {1: set(), 3: set()}


def test_large_scale_many_chains_performance_and_correctness():
    # Build many independent chains of length 3: for i in [0..N-1], create (i,), (i,0), (i,0,0)
    N = 1000  # number of independent chains
    dims = {1, 2, 3}

    # Create a small number of masks and distribute indices round-robin across them
    num_masks = 10
    masks: List[Set[DynamicBatchIndex]] = [set() for _ in range(num_masks)]

    # Track expected tuple sets for each dimension
    expected_dim1 = set()
    expected_dim2 = set()
    expected_dim3 = set()

    for i in range(N):
        idx1 = DynamicBatchIndex((i,))  # length 1
        idx2 = DynamicBatchIndex((i, 0))  # length 2
        idx3 = DynamicBatchIndex((i, 0, 0))  # length 3

        # place each index into one of the masks (round-robin)
        masks[(3 * i) % num_masks].add(idx1)
        masks[(3 * i + 1) % num_masks].add(idx2)
        masks[(3 * i + 2) % num_masks].add(idx3)

        # record expected tuples (a full chain exists for every i)
        expected_dim1.add((i,))
        expected_dim2.add((i, 0))
        expected_dim3.add((i, 0, 0))

    # Run the function under test
    result = get_masks_intersection_for_dimensions(
        masks, dims
    )  # 1.67ms -> 1.35ms (23.8% faster)
    mapped = _to_tuple_map(result)

    # Verify sizes first for quick failure if something went very wrong
    assert len(mapped[1]) == N
    assert len(mapped[2]) == N
    assert len(mapped[3]) == N

    # Verify the contents exactly match what we built
    assert mapped == {1: expected_dim1, 2: expected_dim2, 3: expected_dim3}
from typing import Dict, List, Optional, Set

# imports
import pytest
from inference.core.workflows.execution_engine.v1.executor.execution_data_manager.dynamic_batches_manager import (
    DynamicBatchIndex,
)

# Import the function to test
from inference.core.workflows.execution_engine.v1.executor.execution_data_manager.step_input_assembler import (
    get_masks_intersection_for_dimensions,
)


def test_empty_batch_masks():
    """Test that empty batch_masks returns None for all dimensions."""
    # When batch_masks is empty, all dimensions should map to None
    result = get_masks_intersection_for_dimensions(
        [], {1, 2, 3}
    )  # 1.27μs -> 1.33μs (4.50% slower)
    assert result == {1: None, 2: None, 3: None}
    assert all(v is None for v in result.values())


def test_single_dimension_single_mask():
    """Test with a single dimension and a single mask."""
    # Create a simple mask with one index of length 1
    batch_masks = [{(0,)}]
    dimensions = {1}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 3.22μs -> 3.27μs (1.53% slower)
    # Should return the index since it matches the dimension
    assert result == {1: {(0,)}}


def test_single_dimension_multiple_indices():
    """Test with a single dimension and multiple indices."""
    # Create a mask with multiple indices of length 1
    batch_masks = [{(0,), (1,), (2,)}]
    dimensions = {1}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 3.21μs -> 3.29μs (2.43% slower)
    # All three indices should be included
    assert result == {1: {(0,), (1,), (2,)}}


def test_multiple_dimensions_hierarchical_structure():
    """Test with hierarchical indices (tuples of increasing length)."""
    # Create masks with indices of different lengths forming a hierarchy
    batch_masks = [
        {(0,), (0, 1), (0, 1, 2)},  # indices with lengths 1, 2, 3
    ]
    dimensions = {1, 2, 3}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 10.0μs -> 9.93μs (0.816% faster)
    # Each dimension should contain indices of matching length
    assert 1 in result
    assert 2 in result
    assert 3 in result
    assert (0,) in result[1]
    assert (0, 1) in result[2]
    assert (0, 1, 2) in result[3]


def test_two_dimensions_with_hierarchy():
    """Test intersection logic with two dimensions."""
    # Create a complete parent-child hierarchy
    batch_masks = [{(0,), (0, 1), (1,), (1, 2)}]
    dimensions = {1, 2}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 9.34μs -> 8.88μs (5.21% faster)
    # Check that the hierarchy is correctly identified
    assert result[1] is not None
    assert result[2] is not None


def test_single_dimension_with_multiple_masks():
    """Test combining multiple masks for a single dimension."""
    # Multiple masks each containing different indices
    batch_masks = [{(0,)}, {(1,)}, {(2,)}]
    dimensions = {1}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 3.59μs -> 3.68μs (2.45% slower)
    # All indices from all masks should be combined
    assert (0,) in result[1]
    assert (1,) in result[1]
    assert (2,) in result[1]


def test_no_valid_hierarchy():
    """Test when indices don't form a valid complete hierarchy."""
    # Indices (0,) and (1, 2) don't have a parent-child relationship
    batch_masks = [{(0,), (1, 2)}]
    dimensions = {1, 2}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 8.03μs -> 7.92μs (1.40% faster)
    # (0,) has no child and (1, 2) has no parent in the mask, so no complete
    # chain exists; only the structure of the result is checked here
    assert isinstance(result, dict)
    assert 1 in result
    assert 2 in result


def test_single_index_multiple_dimensions():
    """Test a single index across multiple dimension levels."""
    # Create a single index hierarchy (0,) -> (0, 1) -> (0, 1, 2)
    batch_masks = [{(0,), (0, 1), (0, 1, 2)}]
    dimensions = {1, 2, 3}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 10.0μs -> 9.80μs (2.36% faster)
    # The complete chain should survive at every dimension
    assert result == {1: {(0,)}, 2: {(0, 1)}, 3: {(0, 1, 2)}}


def test_unsorted_dimensions():
    """Test that unsorted dimensions are handled correctly."""
    # Provide dimensions in non-sorted order
    batch_masks = [{(0,), (0, 1), (0, 1, 2)}]
    dimensions = {3, 1, 2}  # unsorted input
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 9.59μs -> 9.14μs (4.94% faster)
    # Result should have all dimensions as keys
    assert set(result.keys()) == {1, 2, 3}


def test_large_tuple_indices():
    """Test with very long index tuples (dimension 10)."""
    # Create indices with length 10
    idx = tuple(range(10))
    batch_masks = [{idx}]
    dimensions = {10}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 3.42μs -> 3.43μs (0.292% slower)
    assert 10 in result


def test_many_dimensions():
    """Test with a large number of dimensions (20 levels)."""
    # Create indices for 20 different dimension levels
    idx_chain = []
    for i in range(1, 21):
        idx_chain.append(tuple(range(i)))
    batch_masks = [set(idx_chain)]
    dimensions = set(range(1, 21))
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 40.3μs -> 39.4μs (2.34% faster)
    # Should handle all 20 dimensions
    assert len(result) == 20
    assert all(d in result for d in dimensions)


def test_duplicate_indices_in_mask():
    """Test that duplicate indices in a mask are handled correctly."""
    # Create a mask with duplicate indices (sets handle this automatically)
    batch_masks = [{(0,), (0,), (0,)}]  # duplicates
    dimensions = {1}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 3.10μs -> 3.22μs (3.73% slower)
    # Should only contain the unique index
    assert result[1] == {(0,)}


def test_multiple_masks_with_overlapping_indices():
    """Test multiple masks with overlapping indices."""
    # Multiple masks with some overlapping indices
    batch_masks = [{(0,), (0, 1)}, {(0,), (1,), (1, 2)}]
    dimensions = {1, 2}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 9.49μs -> 9.25μs (2.60% faster)
    # Both masks should be merged
    assert isinstance(result, dict)
    assert all(k in result for k in dimensions)


def test_zero_value_in_tuple():
    """Test indices containing zero values."""
    # Indices with 0 values should be handled correctly
    batch_masks = [{(0,), (0, 0), (0, 0, 0)}]
    dimensions = {1, 2, 3}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 9.73μs -> 9.57μs (1.68% faster)
    # Should not treat 0 specially
    assert isinstance(result, dict)


def test_large_values_in_tuple():
    """Test indices with large integer values."""
    # Use large integers in the tuples
    idx1 = (999,)
    idx2 = (999, 9999)
    batch_masks = [{idx1, idx2}]
    dimensions = {1, 2}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 7.88μs -> 7.48μs (5.22% faster)
    assert isinstance(result, dict)


def test_nested_empty_sets_in_masks():
    """Test when a mask is an empty set."""
    # One mask is empty, others have content
    batch_masks = [set(), {(0,)}, {(0, 1)}]
    dimensions = {1, 2}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 8.54μs -> 8.32μs (2.53% faster)
    assert isinstance(result, dict)


def test_large_number_of_indices_in_single_mask():
    """Test performance with many indices in a single mask."""
    # Create 1000 indices of length 1
    large_mask = {(i,) for i in range(1000)}
    batch_masks = [large_mask]
    dimensions = {1}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 69.7μs -> 68.2μs (2.29% faster)
    # Should handle all 1000 indices
    assert len(result[1]) == 1000


def test_large_number_of_masks():
    """Test performance with many masks."""
    # Create 1000 different masks, each with a few indices
    batch_masks = [{(i,), (i, 0)} for i in range(1000)]
    dimensions = {1, 2}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 1.00ms -> 774μs (29.7% faster)
    # All unique indices should be aggregated
    assert 1 in result
    assert 2 in result


def test_deep_hierarchy_chain():
    """Test with a deep hierarchy of indices."""
    # Create a single chain of indices: (0,) -> (0, 0) -> ... -> (0, 0, ..., 0) with 50 zeros
    indices = set()
    for i in range(1, 51):
        indices.add(tuple([0] * i))
    batch_masks = [indices]
    dimensions = set(range(1, 51))
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 103μs -> 107μs (3.14% slower)
    # Should handle the deep hierarchy
    assert len(result) == 50


def test_wide_hierarchy_many_siblings():
    """Test with many sibling indices at each level."""
    # Create a wide, shallow tree: level 1 holds (0,) to (99,);
    # level 2 holds one child (i, 0) per parent
    mask = set()
    # Add 100 indices at dimension 1
    for i in range(100):
        mask.add((i,))
    # Add one child per parent at dimension 2
    for i in range(100):
        mask.add((i, 0))
    batch_masks = [mask]
    dimensions = {1, 2}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 103μs -> 80.9μs (28.3% faster)
    assert 1 in result
    assert 2 in result


def test_many_separate_hierarchies():
    """Test multiple independent hierarchies in different masks."""
    # Create 50 independent hierarchies (parent-child chains)
    batch_masks = []
    for h in range(50):
        hierarchy = {(h,), (h, 0), (h, 0, 0)}
        batch_masks.append(hierarchy)
    dimensions = {1, 2, 3}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 95.8μs -> 78.2μs (22.5% faster)
    # Should merge all hierarchies
    assert 1 in result
    assert 2 in result
    assert 3 in result


def test_complex_multi_branch_hierarchy():
    """Test a complex tree with multiple branches."""
    # Create a tree structure with branches
    mask = set()
    # Root level
    mask.add((0,))
    # Level 2: (0,0) and (0,1)
    mask.add((0, 0))
    mask.add((0, 1))
    # Level 3: children of (0,0)
    mask.add((0, 0, 0))
    mask.add((0, 0, 1))
    # Level 3: children of (0,1)
    mask.add((0, 1, 0))
    mask.add((0, 1, 1))
    # Level 4: grandchildren
    mask.add((0, 0, 0, 0))
    mask.add((0, 1, 1, 0))
    batch_masks = [mask]
    dimensions = {1, 2, 3, 4}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 14.3μs -> 13.5μs (5.77% faster)
    assert len(result) == 4


def test_very_wide_dimensions():
    """Test with many indices (a 100 x 100 grid) at a single dimension."""
    # Create every (i, j) pair for i and j in 0..99
    mask = set()
    for i in range(100):
        for j in range(100):
            mask.add((i, j))
    batch_masks = [mask]
    dimensions = {2}
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 776μs -> 768μs (1.01% faster)
    # Should handle up to 10000 indices at dimension 2
    assert 2 in result


def test_stress_test_combined_complexity():
    """Stress test with large scale, deep hierarchy, and many masks."""
    # Combine multiple sources of complexity
    batch_masks = []
    # 10 different masks
    for m in range(10):
        mask = set()
        # Each mask has a hierarchy up to depth 10 with multiple branches
        for branch in range(5):
            base = (m * 5 + branch,)
            mask.add(base)
            for depth in range(2, 11):
                idx = base + tuple([0] * (depth - 1))
                mask.add(idx)
        batch_masks.append(mask)
    dimensions = set(range(1, 11))
    result = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 378μs -> 327μs (15.6% faster)
    # Should complete without error and return valid structure
    assert len(result) == 10
    assert all(isinstance(v, (set, type(None))) for v in result.values())


def test_return_type_consistency():
    """Test that return type is always Dict[int, Optional[Set[DynamicBatchIndex]]]."""
    # Various inputs should all return the correct type
    test_cases = [
        ([], {1}),
        ([{(0,)}], {1}),
        ([{(0,), (0, 1)}], {1, 2}),
    ]
    for batch_masks, dimensions in test_cases:
        result = get_masks_intersection_for_dimensions(
            batch_masks, dimensions
        )  # 10.6μs -> 10.6μs (0.085% slower)
        assert isinstance(result, dict)
        for k, v in result.items():
            assert isinstance(k, int)
            assert v is None or isinstance(v, set)
            if isinstance(v, set):
                # Each element should be a tuple (DynamicBatchIndex)
                for elem in v:
                    assert isinstance(elem, tuple)


def test_idempotent_operation():
    """Test that calling the function twice with same inputs gives same results."""
    # Run the same operation twice
    batch_masks = [{(0,), (0, 1), (0, 1, 2)}]
    dimensions = {1, 2, 3}
    result1 = get_masks_intersection_for_dimensions(
        batch_masks, dimensions
    )  # 9.62μs -> 9.29μs (3.55% faster)
    result2 = get_masks_intersection_for_dimensions(batch_masks, dimensions)
    # Results should be identical
    assert result1 == result2  # 5.98μs -> 5.73μs (4.36% faster)


def test_order_independence_of_dimensions():
    """Test that dimension order doesn't affect results."""
    # Test with different orderings of the same dimensions
    batch_masks = [{(0,), (0, 1), (0, 1, 2)}]
    result1 = get_masks_intersection_for_dimensions(
        batch_masks, {1, 2, 3}
    )  # 9.25μs -> 8.93μs (3.58% faster)
    result2 = get_masks_intersection_for_dimensions(batch_masks, {3, 1, 2})
    result3 = get_masks_intersection_for_dimensions(
        batch_masks, {2, 3, 1}
    )  # 5.92μs -> 5.53μs (7.05% faster)
    # All three should give the same result
    assert result1 == result2
    assert result2 == result3  # 4.70μs -> 4.55μs (3.30% faster)


def test_order_independence_of_mask_order():
    """Test that the order of masks in batch_masks doesn't matter."""
    # Create same masks in different orders
    batch_masks1 = [{(0,)}, {(0, 1)}, {(0, 1, 2)}]
    batch_masks2 = [{(0, 1, 2)}, {(0,)}, {(0, 1)}]
    batch_masks3 = [{(0, 1)}, {(0, 1, 2)}, {(0,)}]
    dimensions = {1, 2, 3}
    result1 = get_masks_intersection_for_dimensions(
        batch_masks1, dimensions
    )  # 9.06μs -> 9.15μs (0.984% slower)
    result2 = get_masks_intersection_for_dimensions(batch_masks2, dimensions)
    result3 = get_masks_intersection_for_dimensions(
        batch_masks3, dimensions
    )  # 5.66μs -> 5.73μs (1.22% slower)
    # All should produce the same result
    assert result1 == result2
    assert result2 == result3  # 4.92μs -> 4.73μs (4.02% faster)

To test or edit this optimization locally, run: git merge codeflash/optimize-pr2108-2026-03-12T23.38.09

Suggested change
-    for dim in reversed(sorted_dims):
-        for idx in by_dim[dim]:
-            if dim == sorted_dims[-1] or idx in has_child:
-                parent = idx[:-1]
-                if parent:
-                    has_child.add(parent)
-    # Early exit if intersection becomes empty
-    if not intersection:
-        return set()
-    # Top-down: keep indices only if full prefix chain exists
-    valid: Dict[int, Set[DynamicBatchIndex]] = {dim: set() for dim in sorted_dims}
-    for dim in sorted_dims:
-        for idx in by_dim[dim]:
-            parent = idx[:-1]
-            if dim == sorted_dims[0]:
-                if idx in has_child:
-                    valid[dim].add(idx)
-            elif parent in valid[prev_dim[dim]]:
-                if dim == sorted_dims[-1] or idx in has_child:
-                    valid[dim].add(idx)
+    last_dim = sorted_dims[-1]
+    for dim in reversed(sorted_dims):
+        if dim == last_dim:
+            for idx in by_dim[dim]:
+                parent = idx[:-1]
+                if parent:
+                    has_child.add(parent)
+        else:
+            for idx in by_dim[dim]:
+                if idx in has_child:
+                    parent = idx[:-1]
+                    if parent:
+                        has_child.add(parent)
+    # Top-down: keep indices only if full prefix chain exists
+    valid: Dict[int, Set[DynamicBatchIndex]] = {dim: set() for dim in sorted_dims}
+    first_dim = sorted_dims[0]
+    for dim in sorted_dims:
+        if dim == first_dim:
+            for idx in by_dim[dim]:
+                if idx in has_child:
+                    valid[dim].add(idx)
+        else:
+            prev_set = valid[prev_dim[dim]]
+            if dim == last_dim:
+                for idx in by_dim[dim]:
+                    parent = idx[:-1]
+                    if parent in prev_set:
+                        valid[dim].add(idx)
+            else:
+                for idx in by_dim[dim]:
+                    parent = idx[:-1]
+                    if parent in prev_set and idx in has_child:
+                        valid[dim].add(idx)
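
To see the two passes end to end, here is a self-contained sketch of the chain filtering, written under several assumptions rather than taken from the repository: `Index` stands in for `DynamicBatchIndex` as a plain tuple, `filter_full_chains` is a hypothetical name, masks are merged by union (as the generated tests above suggest), `by_dim` groups indices by length, `prev_dim` maps each dimension to the previous requested one, and at least two dimensions are requested. The real function also handles cases this sketch ignores, such as empty masks and a single dimension.

```python
from typing import Dict, Iterable, Set, Tuple

# Stand-in for DynamicBatchIndex (assumption: a plain tuple of ints).
Index = Tuple[int, ...]


def filter_full_chains(masks: Iterable[Set[Index]], dimensions: Set[int]) -> Dict[int, Set[Index]]:
    """Keep only indices that belong to a complete parent/child chain.

    Hypothetical helper; assumes len(dimensions) >= 2.
    """
    sorted_dims = sorted(dimensions)
    # Group all indices from all masks by their length (assumed construction).
    by_dim: Dict[int, Set[Index]] = {dim: set() for dim in sorted_dims}
    for mask in masks:
        for idx in mask:
            if len(idx) in by_dim:
                by_dim[len(idx)].add(idx)
    # Map each dimension to the previous requested dimension.
    prev_dim = {dim: prev for prev, dim in zip(sorted_dims, sorted_dims[1:])}
    first_dim, last_dim = sorted_dims[0], sorted_dims[-1]

    # Bottom-up: record every index that has a descendant reaching the last level.
    has_child: Set[Index] = set()
    for dim in reversed(sorted_dims):
        for idx in by_dim[dim]:
            if dim == last_dim or idx in has_child:
                parent = idx[:-1]
                if parent:
                    has_child.add(parent)

    # Top-down: keep an index only if its whole prefix chain was kept.
    valid: Dict[int, Set[Index]] = {dim: set() for dim in sorted_dims}
    for dim in sorted_dims:
        if dim == first_dim:
            valid[dim] = {idx for idx in by_dim[dim] if idx in has_child}
            continue
        prev_set = valid[prev_dim[dim]]
        for idx in by_dim[dim]:
            if idx[:-1] in prev_set and (dim == last_dim or idx in has_child):
                valid[dim].add(idx)
    return valid


# Same data as the hierarchical-chain regression test above.
masks = [{(0,), (0, 1, 2), (9, 9, 9)}, {(0, 1), (7, 8), (5,)}]
chains = filter_full_chains(masks, {1, 2, 3})
assert chains == {1: {(0,)}, 2: {(0, 1)}, 3: {(0, 1, 2)}}
```

Only the chain (0,) -> (0, 1) -> (0, 1, 2) survives: (5,) has no child, (7, 8) has no parent, and (9, 9, 9) has no parent chain.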

