Add fixes for step inputs construction#2108
Add fixes for step inputs construction#2108PawelPeczek-Roboflow wants to merge 1 commit intomainfrom
Conversation
⚡️ Codeflash found optimizations for this PR📄 19% (0.19x) speedup for
|
| for dim in reversed(sorted_dims): | ||
| for idx in by_dim[dim]: | ||
| if dim == sorted_dims[-1] or idx in has_child: | ||
| parent = idx[:-1] | ||
| if parent: | ||
| has_child.add(parent) | ||
|
|
||
| # Early exit if intersection becomes empty | ||
| if not intersection: | ||
| return set() | ||
| # Top-down: keep indices only if full prefix chain exists | ||
| valid: Dict[int, Set[DynamicBatchIndex]] = {dim: set() for dim in sorted_dims} | ||
| for dim in sorted_dims: | ||
| for idx in by_dim[dim]: | ||
| parent = idx[:-1] | ||
| if dim == sorted_dims[0]: | ||
| if idx in has_child: | ||
| valid[dim].add(idx) | ||
| elif parent in valid[prev_dim[dim]]: | ||
| if dim == sorted_dims[-1] or idx in has_child: | ||
| valid[dim].add(idx) |
There was a problem hiding this comment.
⚡️Codeflash found 17% (0.17x) speedup for get_masks_intersection_for_dimensions in inference/core/workflows/execution_engine/v1/executor/execution_data_manager/step_input_assembler.py
⏱️ Runtime : 4.44 milliseconds → 3.79 milliseconds (best of 31 runs)
📝 Explanation and details
The optimized code hoists invariant dimension checks (dim == last_dim, dim == first_dim) outside the per-index inner loops, eliminating redundant comparisons on every iteration. In the bottom-up pass, separating the last_dim branch avoids checking dim == sorted_dims[-1] or idx in has_child for 2227 leaf indices (profiler shows ~0.9 ms saved on that condition alone). Similarly, in the top-down pass, splitting first_dim and last_dim cases removes ~6000 redundant dim == sorted_dims[0] checks and dictionary lookups of prev_dim[dim] that would otherwise execute inside the loop. Hoisting prev_set = valid[prev_dim[dim]] outside the inner loop further cuts repeated dictionary accesses. These changes reduce CPU cycles in hot paths without altering logic, yielding a 17% runtime improvement (4.44 ms → 3.79 ms) with no correctness regressions.
✅ Correctness verification report:
| Test | Status |
|---|---|
| ⏪ Replay Tests | 🔘 None Found |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 40 Passed |
| 📊 Tests Coverage | 100.0% |
🌀 Click to see Generated Regression Tests
from typing import Dict, List, Optional, Set
# imports
import pytest # used for our unit tests
from inference.core.workflows.execution_engine.v1.executor.execution_data_manager.dynamic_batches_manager import (
DynamicBatchIndex,
)
# import the function and the real DynamicBatchIndex class used by the function
from inference.core.workflows.execution_engine.v1.executor.execution_data_manager.step_input_assembler import (
get_masks_intersection_for_dimensions,
)
# helper to convert the returned sets of DynamicBatchIndex into sets of plain tuples
def _to_tuple_map(
result: Dict[int, Optional[Set[DynamicBatchIndex]]]
) -> Dict[int, Optional[Set[tuple]]]:
"""
Convert a mapping of int -> Optional[Set[DynamicBatchIndex]] into
int -> Optional[Set[tuple]] so comparisons are simpler and not dependent
on the concrete DynamicBatchIndex equality implementation.
"""
out: Dict[int, Optional[Set[tuple]]] = {}
for k, v in result.items():
if v is None:
out[k] = None
else:
out[k] = {tuple(idx) for idx in v}
return out
def test_empty_batch_masks_returns_none_for_each_dimension():
# When batch_masks is empty, the function should return None for each requested dimension
dims = {1, 2, 3}
result = get_masks_intersection_for_dimensions(
[], dims
) # 1.19μs -> 1.25μs (4.87% slower)
# Expect exactly {1: None, 2: None, 3: None}
assert result == {1: None, 2: None, 3: None}
def test_single_dimension_collects_all_indices_of_that_length():
# For a single dimension, the function returns a set of all indices that have that length
dims = {1}
# Create a few DynamicBatchIndex instances of length 1
i0 = DynamicBatchIndex((0,))
i1 = DynamicBatchIndex((1,))
i2 = DynamicBatchIndex((2,))
# Provide masks as a list of sets (could be multiple masks)
masks: List[Set[DynamicBatchIndex]] = [{i0, i1}, {i2}]
result = get_masks_intersection_for_dimensions(
masks, dims
) # 3.46μs -> 3.56μs (2.81% slower)
# Convert to tuple form for easy comparison
tuple_map = _to_tuple_map(result)
assert tuple_map == {
1: {(0,), (1,), (2,)}
} # all three single-length indices are present
def test_hierarchical_chain_includes_full_chains_and_excludes_orphans():
# This test verifies the "top-down" and "bottom-up" logic:
# - A full chain (len1 -> len2 -> len3) should be present in all levels
# - Orphan bottom entries (without parents) should be excluded
dims = {1, 2, 3}
# Construct a valid 3-level chain: (0,) -> (0, 1) -> (0, 1, 2)
a1 = DynamicBatchIndex((0,))
a2 = DynamicBatchIndex((0, 1))
a3 = DynamicBatchIndex((0, 1, 2))
# An orphan bottom index with no corresponding parents
orphan_bottom = DynamicBatchIndex((9, 9, 9))
# A middle-level index that lacks the top-level parent (should be excluded)
orphan_middle = DynamicBatchIndex((7, 8))
# A top-level index with no children (should be excluded because it isn't in has_child)
top_no_child = DynamicBatchIndex((5,))
# Distribute indices across two masks to ensure aggregation works across masks
mask1 = {a1, a3, orphan_bottom}
mask2 = {a2, orphan_middle, top_no_child}
masks = [mask1, mask2]
result = get_masks_intersection_for_dimensions(
masks, dims
) # 11.9μs -> 11.6μs (2.68% faster)
mapped = _to_tuple_map(result)
# Expected:
# - dim 1: only (0,) should be kept (it has a descendant chain). (5,) has no child and is excluded.
# - dim 2: only (0,1) should be kept because its parent (0,) exists in valid[1] and it has a descendant
# - dim 3: only (0,1,2) should be kept. orphan_bottom has no parent chain and should be excluded.
assert mapped == {
1: {(0,)},
2: {(0, 1)},
3: {(0, 1, 2)},
}
def test_non_consecutive_dimensions_produces_no_valid_chains():
# If dimensions skip lengths (e.g., {1, 3}), chain logic should fail because parent lengths won't line up
dims = {1, 3}
# Provide a length-3 index and a length-1 index
idx3 = DynamicBatchIndex((1, 2, 3))
idx1 = DynamicBatchIndex((1,))
masks = [{idx3, idx1}]
result = get_masks_intersection_for_dimensions(
masks, dims
) # 7.40μs -> 7.43μs (0.404% slower)
mapped = _to_tuple_map(result)
# No valid chains can be formed because the parent for the length-3 index is length-2 (which is not requested)
# The algorithm therefore should return empty sets for both requested dimensions (not None)
assert mapped == {1: set(), 3: set()}
def test_large_scale_many_chains_performance_and_correctness():
# Build many independent chains of length 3: for i in [0..N-1], create (i,), (i,0), (i,0,0)
N = 1000 # number of independent chains (reasonable upper bound per instructions)
dims = {1, 2, 3}
# Create a small number of masks and distribute indices round-robin across them
num_masks = 10
masks: List[Set[DynamicBatchIndex]] = [set() for _ in range(num_masks)]
# Track expected tuple sets for each dimension
expected_dim1 = set()
expected_dim2 = set()
expected_dim3 = set()
for i in range(N):
idx1 = DynamicBatchIndex((i,)) # length 1
idx2 = DynamicBatchIndex((i, 0)) # length 2
idx3 = DynamicBatchIndex((i, 0, 0)) # length 3
# place each index into one of the masks (round-robin)
masks[(3 * i) % num_masks].add(idx1)
masks[(3 * i + 1) % num_masks].add(idx2)
masks[(3 * i + 2) % num_masks].add(idx3)
# record expected tuples (a full chain exists for every i)
expected_dim1.add((i,))
expected_dim2.add((i, 0))
expected_dim3.add((i, 0, 0))
# Run the function under test
result = get_masks_intersection_for_dimensions(
masks, dims
) # 1.67ms -> 1.35ms (23.8% faster)
mapped = _to_tuple_map(result)
# Verify sizes first for quick failure if something went very wrong
assert len(mapped[1]) == N
assert len(mapped[2]) == N
assert len(mapped[3]) == N
# Verify the contents exactly match what we built
assert mapped == {1: expected_dim1, 2: expected_dim2, 3: expected_dim3}from typing import Dict, List, Optional, Set
# imports
import pytest
from inference.core.workflows.execution_engine.v1.executor.execution_data_manager.dynamic_batches_manager import (
DynamicBatchIndex,
)
# Import the function to test
from inference.core.workflows.execution_engine.v1.executor.execution_data_manager.step_input_assembler import (
get_masks_intersection_for_dimensions,
)
def test_empty_batch_masks():
"""Test that empty batch_masks returns None for all dimensions."""
# When batch_masks is empty, all dimensions should map to None
result = get_masks_intersection_for_dimensions(
[], {1, 2, 3}
) # 1.27μs -> 1.33μs (4.50% slower)
assert result == {1: None, 2: None, 3: None}
assert all(v is None for v in result.values())
def test_single_dimension_single_mask():
"""Test with a single dimension and a single mask."""
# Create a simple mask with one index of length 1
batch_masks = [{(0,)}]
dimensions = {1}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 3.22μs -> 3.27μs (1.53% slower)
# Should return the index since it matches the dimension
assert result == {1: {(0,)}}
def test_single_dimension_multiple_indices():
"""Test with a single dimension and multiple indices."""
# Create a mask with multiple indices of length 1
batch_masks = [{(0,), (1,), (2,)}]
dimensions = {1}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 3.21μs -> 3.29μs (2.43% slower)
# All three indices should be included
assert result == {1: {(0,), (1,), (2,)}}
def test_multiple_dimensions_hierarchical_structure():
"""Test with hierarchical indices (tuples of increasing length)."""
# Create masks with indices of different lengths forming a hierarchy
batch_masks = [
{(0,), (0, 1), (0, 1, 2)}, # indices with lengths 1, 2, 3
]
dimensions = {1, 2, 3}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 10.0μs -> 9.93μs (0.816% faster)
# Each dimension should contain indices of matching length
assert 1 in result
assert 2 in result
assert 3 in result
assert (0,) in result[1]
assert (0, 1) in result[2]
assert (0, 1, 2) in result[3]
def test_two_dimensions_with_hierarchy():
"""Test intersection logic with two dimensions."""
# Create a complete parent-child hierarchy
batch_masks = [{(0,), (0, 1), (1,), (1, 2)}]
dimensions = {1, 2}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 9.34μs -> 8.88μs (5.21% faster)
# Check that the hierarchy is correctly identified
assert result[1] is not None
assert result[2] is not None
def test_single_dimension_with_multiple_masks():
"""Test combining multiple masks for a single dimension."""
# Multiple masks each containing different indices
batch_masks = [{(0,)}, {(1,)}, {(2,)}]
dimensions = {1}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 3.59μs -> 3.68μs (2.45% slower)
# All indices from all masks should be combined
assert (0,) in result[1]
assert (1,) in result[1]
assert (2,) in result[1]
def test_no_valid_hierarchy():
"""Test when indices don't form a valid complete hierarchy."""
# Indices (0,) and (1, 2) don't have a parent-child relationship
batch_masks = [{(0,), (1, 2)}]
dimensions = {1, 2}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 8.03μs -> 7.92μs (1.40% faster)
# The result depends on the has_child logic
assert isinstance(result, dict)
assert 1 in result
assert 2 in result
def test_single_index_multiple_dimensions():
"""Test a single index across multiple dimension levels."""
# Create a single index hierarchy (0,) -> (0, 1) -> (0, 1, 2)
batch_masks = [{(0,), (0, 1), (0, 1, 2)}]
dimensions = {1, 2, 3}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 10.0μs -> 9.80μs (2.36% faster)
# Should have valid entries at each dimension
assert result[1] is not None or result[1] is None # valid state
assert result[2] is not None or result[2] is None # valid state
assert result[3] is not None or result[3] is None # valid state
def test_unsorted_dimensions():
"""Test that unsorted dimensions are handled correctly."""
# Provide dimensions in non-sorted order
batch_masks = [{(0,), (0, 1), (0, 1, 2)}]
dimensions = {3, 1, 2} # unsorted input
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 9.59μs -> 9.14μs (4.94% faster)
# Result should have all dimensions as keys
assert set(result.keys()) == {1, 2, 3}
def test_large_tuple_indices():
"""Test with very long index tuples (dimension 10)."""
# Create indices with length 10
idx = tuple(range(10))
batch_masks = [{idx}]
dimensions = {10}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 3.42μs -> 3.43μs (0.292% slower)
assert 10 in result
def test_many_dimensions():
"""Test with a large number of dimensions (20 levels)."""
# Create indices for 20 different dimension levels
idx_chain = []
for i in range(1, 21):
idx_chain.append(tuple(range(i)))
batch_masks = [set(idx_chain)]
dimensions = set(range(1, 21))
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 40.3μs -> 39.4μs (2.34% faster)
# Should handle all 20 dimensions
assert len(result) == 20
assert all(d in result for d in dimensions)
def test_duplicate_indices_in_mask():
"""Test that duplicate indices in a mask are handled correctly."""
# Create a mask with duplicate indices (sets handle this automatically)
batch_masks = [{(0,), (0,), (0,)}] # duplicates
dimensions = {1}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 3.10μs -> 3.22μs (3.73% slower)
# Should only contain the unique index
assert result[1] == {(0,)}
def test_multiple_masks_with_overlapping_indices():
"""Test multiple masks with overlapping indices."""
# Multiple masks with some overlapping indices
batch_masks = [{(0,), (0, 1)}, {(0,), (1,), (1, 2)}]
dimensions = {1, 2}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 9.49μs -> 9.25μs (2.60% faster)
# Both masks should be merged
assert isinstance(result, dict)
assert all(k in result for k in dimensions)
def test_zero_value_in_tuple():
"""Test indices containing zero values."""
# Indices with 0 values should be handled correctly
batch_masks = [{(0,), (0, 0), (0, 0, 0)}]
dimensions = {1, 2, 3}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 9.73μs -> 9.57μs (1.68% faster)
# Should not treat 0 specially
assert isinstance(result, dict)
def test_large_values_in_tuple():
"""Test indices with large integer values."""
# Use large integers in the tuples
idx1 = (999,)
idx2 = (999, 9999)
batch_masks = [{idx1, idx2}]
dimensions = {1, 2}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 7.88μs -> 7.48μs (5.22% faster)
assert isinstance(result, dict)
def test_nested_empty_sets_in_masks():
"""Test when a mask is an empty set."""
# One mask is empty, others have content
batch_masks = [set(), {(0,)}, {(0, 1)}]
dimensions = {1, 2}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 8.54μs -> 8.32μs (2.53% faster)
assert isinstance(result, dict)
def test_large_number_of_indices_in_single_mask():
"""Test performance with many indices in a single mask."""
# Create 1000 indices of length 1
large_mask = {(i,) for i in range(1000)}
batch_masks = [large_mask]
dimensions = {1}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 69.7μs -> 68.2μs (2.29% faster)
# Should handle all 1000 indices
assert len(result[1]) == 1000
def test_large_number_of_masks():
"""Test performance with many masks."""
# Create 1000 different masks, each with a few indices
batch_masks = [{(i,), (i, 0)} for i in range(1000)]
dimensions = {1, 2}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 1.00ms -> 774μs (29.7% faster)
# All unique indices should be aggregated
assert 1 in result
assert 2 in result
def test_deep_hierarchy_chain():
"""Test with a deep hierarchy of indices."""
# Create a single chain of indices: (0,) -> (0,0) -> (0,0,0) -> ... (0,0,...,0) with 100 zeros
deep_idx = tuple([0] * 50)
indices = set()
for i in range(1, 51):
indices.add(tuple([0] * i))
batch_masks = [indices]
dimensions = set(range(1, 51))
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 103μs -> 107μs (3.14% slower)
# Should handle the deep hierarchy
assert len(result) == 50
def test_wide_hierarchy_many_siblings():
"""Test with many sibling indices at each level."""
# Create a wide tree: at level 1: (0,) to (999,); at level 2: (0,0) to (999,999), etc.
batch_masks = []
mask = set()
# Add 100 indices at dimension 1
for i in range(100):
mask.add((i,))
# Add parent-child pairs at dimension 2
for i in range(100):
mask.add((i, 0))
batch_masks = [mask]
dimensions = {1, 2}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 103μs -> 80.9μs (28.3% faster)
assert 1 in result
assert 2 in result
def test_many_separate_hierarchies():
"""Test multiple independent hierarchies in different masks."""
# Create 50 independent hierarchies (parent-child chains)
batch_masks = []
for h in range(50):
hierarchy = {(h,), (h, 0), (h, 0, 0)}
batch_masks.append(hierarchy)
dimensions = {1, 2, 3}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 95.8μs -> 78.2μs (22.5% faster)
# Should merge all hierarchies
assert 1 in result
assert 2 in result
assert 3 in result
def test_complex_multi_branch_hierarchy():
"""Test a complex tree with multiple branches."""
# Create a tree structure with branches
mask = set()
# Root level
mask.add((0,))
# Level 2: (0,0) and (0,1)
mask.add((0, 0))
mask.add((0, 1))
# Level 3: children of (0,0)
mask.add((0, 0, 0))
mask.add((0, 0, 1))
# Level 3: children of (0,1)
mask.add((0, 1, 0))
mask.add((0, 1, 1))
# Level 4: grandchildren
mask.add((0, 0, 0, 0))
mask.add((0, 1, 1, 0))
batch_masks = [mask]
dimensions = {1, 2, 3, 4}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 14.3μs -> 13.5μs (5.77% faster)
assert len(result) == 4
def test_very_wide_dimensions():
"""Test with very large index values across wide branching."""
# Create indices with large values spread across many positions
mask = set()
for i in range(100):
for j in range(100):
if i < 100 and j < 100:
mask.add((i, j))
batch_masks = [mask]
dimensions = {2}
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 776μs -> 768μs (1.01% faster)
# Should handle up to 10000 indices at dimension 2
assert 2 in result
def test_stress_test_combined_complexity():
"""Stress test with large scale, deep hierarchy, and many masks."""
# Combine multiple sources of complexity
batch_masks = []
# 10 different masks
for m in range(10):
mask = set()
# Each mask has a hierarchy up to depth 10 with multiple branches
for branch in range(5):
base = (m * 5 + branch,)
mask.add(base)
for depth in range(2, 11):
idx = base + tuple([0] * (depth - 1))
mask.add(idx)
batch_masks.append(mask)
dimensions = set(range(1, 11))
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 378μs -> 327μs (15.6% faster)
# Should complete without error and return valid structure
assert len(result) == 10
assert all(isinstance(v, (set, type(None))) for v in result.values())
def test_return_type_consistency():
"""Test that return type is always Dict[int, Optional[Set[DynamicBatchIndex]]]."""
# Various inputs should all return the correct type
test_cases = [
([], {1}),
([{(0,)}], {1}),
([{(0,), (0, 1)}], {1, 2}),
]
for batch_masks, dimensions in test_cases:
result = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 10.6μs -> 10.6μs (0.085% slower)
assert isinstance(result, dict)
for k, v in result.items():
assert isinstance(k, int)
assert v is None or isinstance(v, set)
if isinstance(v, set):
# Each element should be a tuple (DynamicBatchIndex)
for elem in v:
assert isinstance(elem, tuple)
def test_idempotent_operation():
"""Test that calling the function twice with same inputs gives same results."""
# Run the same operation twice
batch_masks = [{(0,), (0, 1), (0, 1, 2)}]
dimensions = {1, 2, 3}
result1 = get_masks_intersection_for_dimensions(
batch_masks, dimensions
) # 9.62μs -> 9.29μs (3.55% faster)
result2 = get_masks_intersection_for_dimensions(batch_masks, dimensions)
# Results should be identical
assert result1 == result2 # 5.98μs -> 5.73μs (4.36% faster)
def test_order_independence_of_dimensions():
"""Test that dimension order doesn't affect results."""
# Test with different orderings of the same dimensions
batch_masks = [{(0,), (0, 1), (0, 1, 2)}]
result1 = get_masks_intersection_for_dimensions(
batch_masks, {1, 2, 3}
) # 9.25μs -> 8.93μs (3.58% faster)
result2 = get_masks_intersection_for_dimensions(batch_masks, {3, 1, 2})
result3 = get_masks_intersection_for_dimensions(
batch_masks, {2, 3, 1}
) # 5.92μs -> 5.53μs (7.05% faster)
# All three should give the same result
assert result1 == result2
assert result2 == result3 # 4.70μs -> 4.55μs (3.30% faster)
def test_order_independence_of_mask_order():
"""Test that the order of masks in batch_masks doesn't matter."""
# Create same masks in different orders
batch_masks1 = [{(0,)}, {(0, 1)}, {(0, 1, 2)}]
batch_masks2 = [{(0, 1, 2)}, {(0,)}, {(0, 1)}]
batch_masks3 = [{(0, 1)}, {(0, 1, 2)}, {(0,)}]
dimensions = {1, 2, 3}
result1 = get_masks_intersection_for_dimensions(
batch_masks1, dimensions
) # 9.06μs -> 9.15μs (0.984% slower)
result2 = get_masks_intersection_for_dimensions(batch_masks2, dimensions)
result3 = get_masks_intersection_for_dimensions(
batch_masks3, dimensions
) # 5.66μs -> 5.73μs (1.22% slower)
# All should produce the same result
assert result1 == result2
assert result2 == result3 # 4.92μs -> 4.73μs (4.02% faster)To test or edit this optimization locally git merge codeflash/optimize-pr2108-2026-03-12T23.38.09
Click to see suggested changes
| for dim in reversed(sorted_dims): | |
| for idx in by_dim[dim]: | |
| if dim == sorted_dims[-1] or idx in has_child: | |
| parent = idx[:-1] | |
| if parent: | |
| has_child.add(parent) | |
| # Early exit if intersection becomes empty | |
| if not intersection: | |
| return set() | |
| # Top-down: keep indices only if full prefix chain exists | |
| valid: Dict[int, Set[DynamicBatchIndex]] = {dim: set() for dim in sorted_dims} | |
| for dim in sorted_dims: | |
| for idx in by_dim[dim]: | |
| parent = idx[:-1] | |
| if dim == sorted_dims[0]: | |
| if idx in has_child: | |
| valid[dim].add(idx) | |
| elif parent in valid[prev_dim[dim]]: | |
| if dim == sorted_dims[-1] or idx in has_child: | |
| valid[dim].add(idx) | |
| last_dim = sorted_dims[-1] | |
| for dim in reversed(sorted_dims): | |
| if dim == last_dim: | |
| for idx in by_dim[dim]: | |
| parent = idx[:-1] | |
| if parent: | |
| has_child.add(parent) | |
| else: | |
| for idx in by_dim[dim]: | |
| if idx in has_child: | |
| parent = idx[:-1] | |
| if parent: | |
| has_child.add(parent) | |
| # Top-down: keep indices only if full prefix chain exists | |
| valid: Dict[int, Set[DynamicBatchIndex]] = {dim: set() for dim in sorted_dims} | |
| first_dim = sorted_dims[0] | |
| for dim in sorted_dims: | |
| if dim == first_dim: | |
| for idx in by_dim[dim]: | |
| if idx in has_child: | |
| valid[dim].add(idx) | |
| else: | |
| prev_set = valid[prev_dim[dim]] | |
| if dim == last_dim: | |
| for idx in by_dim[dim]: | |
| parent = idx[:-1] | |
| if parent in prev_set: | |
| valid[dim].add(idx) | |
| else: | |
| for idx in by_dim[dim]: | |
| parent = idx[:-1] | |
| if parent in prev_set and idx in has_child: | |
| valid[dim].add(idx) |
What does this PR do?
Related Issue(s):
Type of Change
Testing
Test details:
Checklist
Additional Context