Skip to content

Commit b9618c9

Browse files
shinkpytorchmergebot
authored andcommitted
[Dynamo] Add itertools.compress() support (pytorch#139061)
Use polyfill to add `itertools.compress()` support in Dynamo. Pull Request resolved: pytorch#139061 Approved by: https://github.com/jansel
1 parent e201460 commit b9618c9

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

test/dynamo/test_functions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,22 @@ def test_itertools_pairwise(a):
326326
pairs.append(torch.ones(size))
327327
return pairs
328328

329+
def test_itertools_compress(self):
330+
def fn():
331+
return itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1])
332+
333+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
334+
self.assertListEqual(list(opt_fn()), list(fn()))
335+
336+
def test_itertools_compress_tensors(self):
337+
def fn():
338+
return itertools.compress(
339+
[torch.tensor([0]), torch.tensor([1]), torch.tensor([2])], [1, 0, 1]
340+
)
341+
342+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
343+
self.assertListEqual(list(opt_fn()), list(fn()))
344+
329345
@make_test
330346
def test_np_iinfo(a):
331347
max_dim = np.iinfo(np.int16).max

torch/_dynamo/polyfills/itertools.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import itertools
88
import sys
9-
from typing import Iterable, Iterator, TypeVar
9+
from typing import Generator, Iterable, Iterator, TypeVar
1010

1111
from ..decorators import substitute_in_graph
1212

@@ -16,10 +16,12 @@
1616
"chain_from_iterable",
1717
"islice",
1818
"tee",
19+
"compress",
1920
]
2021

2122

2223
_T = TypeVar("_T")
24+
_U = TypeVar("_U")
2325

2426

2527
# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain
@@ -101,3 +103,11 @@ def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def]
101103
return
102104

103105
return tuple(_tee(shared_link) for _ in range(n))
106+
107+
108+
# Reference: https://docs.python.org/3/library/itertools.html#itertools.compress
109+
@substitute_in_graph(itertools.compress, is_embedded_type=True) # type: ignore[arg-type]
110+
def compress(
111+
data: Iterable[_T], selectors: Iterable[_U], /
112+
) -> Generator[_T, None, None]:
113+
return (datum for datum, selector in zip(data, selectors) if selector)

0 commit comments

Comments
 (0)