Skip to content

Commit d894f3e

Browse files
author
Sune Debel
authored
improve speed of trampoline.sequence
1 parent 0027ad4 commit d894f3e

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

pfun/aio_trampoline.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from abc import ABC, abstractmethod
22
from asyncio import iscoroutine
3-
from typing import Awaitable, Callable, Generic, Iterable, TypeVar, Union, cast
3+
from typing import (Awaitable, Callable, Generic, Iterable, List, TypeVar,
4+
Union, cast)
45

56
from .immutable import Immutable
6-
from .monad import Monad, sequence_
7+
from .monad import Monad
78

89
A = TypeVar('A', covariant=True)
910
B = TypeVar('B')
@@ -127,7 +128,7 @@ async def thunk():
127128
return AndThen(self.sub, cont)
128129

129130

130-
def sequence(iterable: Iterable[Trampoline[A]]) -> Trampoline[Iterable[A]]:
131+
def sequence(iterable: Iterable[Trampoline[B]]) -> Trampoline[Iterable[B]]:
131132
"""
132133
Evaluate each :class:`Trampoline` in `iterable` from left to right
133134
and collect the results
@@ -139,7 +140,18 @@ def sequence(iterable: Iterable[Trampoline[A]]) -> Trampoline[Iterable[A]]:
139140
:param iterable: The iterable to collect results from
140141
:returns: ``Trampoline`` of collected results
141142
"""
142-
return cast(Trampoline[Iterable[A]], sequence_(Done, iterable))
143+
def combine(rs: Trampoline[List[B]],
144+
t: Trampoline[B]) -> Trampoline[List[B]]:
145+
return rs.and_then(
146+
lambda xs: t.map(
147+
lambda x: (xs.append(x), xs)[1] # type: ignore
148+
)
149+
)
150+
151+
result: Trampoline[List[B]] = Done([])
152+
for trampoline in iterable:
153+
result = combine(result, trampoline)
154+
return result.map(tuple)
143155

144156

145157
__all__ = ['Trampoline', 'Done', 'sequence', 'Call', 'AndThen']

tests/test_effect.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def test_from_callable_io_bound(self, f):
242242
assert effect.from_callable(effect.io_bound(f)
243243
).run(None) == f(None).get
244244

245+
@settings(deadline=None)
245246
@given(unaries())
246247
def test_catch_cpu_bound(self, f):
247248
assert effect.catch(Exception)(effect.cpu_bound(f)

0 commit comments

Comments
 (0)