Commit 1693352

ability to backpressure via queue, and propagate on cancellation or error
1 parent 5ba5055 commit 1693352
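
The gist of the change: elements are buffered in a bounded cats-effect Queue (so a slow consumer backpressures the producer), the producer runs as a background stream via concurrently, and a Semaphore tracks how many elements are ready to flush. Below is a minimal, self-contained sketch of the queue-based batching idea, assuming cats-effect 3 and fs2 3.x; the object name, sizes and timings are illustrative and not taken from this commit.

import scala.concurrent.duration._
import cats.effect.{IO, IOApp}
import cats.effect.std.Queue
import fs2.{Chunk, Stream}

object QueueBackpressureSketch extends IOApp.Simple {
  val chunkSize = 4

  def run: IO[Unit] =
    Stream.eval(Queue.bounded[IO, Int](chunkSize)).flatMap { buffer =>
      // Producer: offer blocks once chunkSize elements are buffered, so a slow
      // consumer slows the producer down instead of growing an unbounded buffer.
      val produce = Stream.range(0, 20).covary[IO].foreach(buffer.offer)

      // Consumer: periodically drain up to chunkSize elements as one chunk and
      // stop on the first empty batch (the producer has finished by then).
      val consume = Stream
        .repeatEval(buffer.tryTakeN(Some(chunkSize)).map(Chunk.seq))
        .metered(100.millis)
        .collectWhile { case c if c.nonEmpty => c }
        .evalMap(c => IO.println(s"batch: ${c.toList}"))

      consume.concurrently(produce)
    }.compile.drain
}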

2 files changed: +71 −52 lines

core/shared/src/main/scala/fs2/Stream.scala (35 additions, 52 deletions)

@@ -1510,80 +1510,63 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
       timeout: FiniteDuration
   )(implicit F: Temporal[F2]): Stream[F2, Chunk[O]] = {

-    case class State[+A](os: Chunk[A], supplyEnded: Boolean) {
-
-      // checking if it's empty to avoid early termination of the stream if the producer is faster than consumers
-      def streamExhausted: Boolean = supplyEnded && os.isEmpty
-
-      def endSupply: State[A] = copy(supplyEnded = true)
-
-      def splitAt(idx: Long): (State[A], Chunk[A]) = {
-        val (flushed, kept) = os.splitAt(idx.toInt)
-        (copy(os = kept), flushed)
-      }
-
-      override def toString =
-        s"State(size = ${os.size}, supplyEnded = $supplyEnded, streamExhausted = $streamExhausted), os = $os"
-    }
-
-    object State {
-      def add[A](a: A)(s: State[A]): State[A] = s.copy(os = s.os ++ Chunk.singleton(a))
-    }
+    val groupSize: Long = chunkSize.toLong

     if (timeout.toNanos == 0) chunkN(chunkSize)
     else
       Stream.force {
         for {
-          state <- Ref.of(State[O](os = Chunk.empty, supplyEnded = false))
           supply <- Semaphore[F2](0)
+          supplyEnded <- Ref.of(false)
+          buffer <- Queue.bounded[F2, O](chunkSize) // buffering and backpressure
         } yield {

-          val groupSize = chunkSize.toLong
-
-          val enqueue: F2[Unit] =
-            foreach(o => state.update(State.add[O](o)) *> supply.release)
-              .covary[F2]
-              .compile
-              .drain
+          val emitChunk: F2[Chunk[O]] =
+            buffer.tryTakeN(Some(groupSize.toInt)).map(Chunk.seq)

-          // run at the end when there's no need to wait
-          val maxOutSupply = supply.releaseN(Int.MaxValue)
+          // we need to check the buffer size, rather than the available supply since
+          // the supply is maxed out at the end so won't always indicate accurately the buffer size
+          val isBufferEmpty: F2[Boolean] =
+            buffer.size.map(_ == 0)

-          val markSupplyEnd = state.update(_.endSupply)
+          val streamExhausted: F2[Boolean] =
+            (isBufferEmpty, supplyEnded.get).mapN(_ && _)

-          val enqueueAsync = F.start(enqueue.guarantee(markSupplyEnd *> maxOutSupply))
+          // releasing a number of permits equal to {groupSize} should be enough,
+          // but in order to ensure termination of the consumer when the producer
+          // stream is cancelled we need to max out the supply
+          val maxOutSupply: F2[Unit] =
+            supply.available.flatMap(a => supply.releaseN(Long.MaxValue - a))

-          // returns the number of elements to be flushed on the left &
-          // the number of permits the supply should be lowered by on the right
-          val awaitNext: F2[(Long, Long)] = {
+          // enabling termination of the consumer stream when the producer completes
+          // naturally (i.e runs out of elements)
+          val markSupplyEnd: F2[Unit] = supplyEnded.set(true)

-            val nextChunk = for {
-              _ <- supply.acquire
-              a <- supply.available.map(_ + 1) // buffer size = acquired (1) + available
-              flushSize = a.min(groupSize) // apply cap (producer is still running)
-            } yield (flushSize, flushSize - 1) // decrement permits (one already acquired)
-
-            F.ifM(supply.tryAcquireN(groupSize))(F.pure((groupSize, groupSize)), nextChunk)
-          }
+          val enqueue: F2[Unit] =
+            foreach(buffer.offer(_) <* supply.release).compile.drain
+              .guarantee(markSupplyEnd *> maxOutSupply)

-          def emitChunk(n: Long): F2[Chunk[O]] = state.modify(_.splitAt(n))
+          val awaitAndEmitNext: F2[Chunk[O]] = for {
+            isEmpty <- isBufferEmpty
+            awaited <- supply.acquire.whenA(isEmpty).as(if (isEmpty) 1 else 0)
+            flushed <- emitChunk
+            // lower supply by {flushed.size} (excluding element already awaited)
+            _ <- supply.tryAcquireN((flushed.size.toLong - awaited).max(0))
+          } yield flushed

           val onTimeout =
-            F.ifM(state.get.map(_.streamExhausted))(
-              F.pure(Chunk.empty[O]), // emit an empty chunk to end the stream
-              awaitNext.flatMap { case (n, p) => emitChunk(n) <* supply.tryAcquireN(p) }
-            )
+            F.ifM(streamExhausted)(F.pure(Chunk.empty[O]), awaitAndEmitNext)

           val dequeue: F2[Chunk[O]] =
             F.race(supply.acquireN(groupSize), F.sleep(timeout)).flatMap {
-              case Left(_) => emitChunk(groupSize)
+              case Left(_) => emitChunk
               case Right(_) => onTimeout
             }

-          Stream.bracket(enqueueAsync)(fib => markSupplyEnd *> fib.cancel) >>
-            Stream
-              .repeatEval(dequeue)
-              .collectWhile { case os if os.nonEmpty => os }
+          Stream
+            .repeatEval(dequeue)
+            .collectWhile { case os if os.nonEmpty => os }
+            .concurrently(Stream.eval(enqueue))
         }
       }
   }
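
The structural change at the end of this diff is what makes the propagation in the commit title work: instead of a bracketed fiber, the producer (enqueue) now runs as the background of concurrently, which interrupts it when the consumer halts and surfaces its errors downstream. A small illustration of that behaviour, assuming cats-effect 3 and fs2 3.x (the object name, error type and durations are made up for the example):

import scala.concurrent.duration._
import cats.effect.{IO, IOApp}
import fs2.Stream

object ConcurrentlyPropagation extends IOApp.Simple {
  case object Boom extends Exception("boom")

  def run: IO[Unit] = {
    // Background stream fails after a short delay, like an upstream producer raising an error.
    val failingProducer = Stream.sleep[IO](200.millis) ++ Stream.raiseError[IO](Boom)

    // Foreground stream would otherwise run forever.
    val consumer = Stream.awakeEvery[IO](50.millis)

    consumer
      .concurrently(failingProducer)
      .compile
      .drain
      .attempt
      .flatMap(result => IO.println(s"consumer ended with: $result")) // prints Left(Boom)
  }
}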

core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala (36 additions, 0 deletions)

@@ -34,6 +34,7 @@ import org.scalacheck.Prop.forAll

 import scala.concurrent.duration._
 import scala.concurrent.TimeoutException
+import scala.util.control.NoStackTrace

 class StreamCombinatorsSuite extends Fs2Suite {

@@ -831,6 +832,41 @@
         )
         .assertEquals(0.millis)
     }
+
+    test("Propagation: upstream failures are propagated downstream") {
+
+      case object SevenNotAllowed extends NoStackTrace
+
+      val source = Stream
+        .unfold(0)(s => Some(s, s + 1))
+        .covary[IO]
+        .evalMap(n => if (n == 7) IO.raiseError(SevenNotAllowed) else IO.pure(n))
+
+      val downstream = source.groupWithin(100, 2.seconds)
+
+      downstream.compile.lastOrError.intercept[SevenNotAllowed.type]
+    }
+
+    test("Propagation: upstream cancellation is propagated downstream") {
+
+      def source(counter: Ref[IO, Int]): Stream[IO, Int] = {
+        Stream
+          .unfold(0)(s => Some(s, s + 1))
+          .covary[IO]
+          .meteredStartImmediately(1.second)
+          .evalTap(counter.set)
+          .interruptAfter(5.5.seconds)
+      }
+
+      def downstream(counter: Ref[IO, Int]): Stream[IO, Chunk[Int]] =
+        source(counter).groupWithin(Int.MaxValue, 1.day)
+
+      (for {
+        counter <- Ref.of[IO, Int](0)
+        _ <- downstream(counter).compile.drain
+        c <- counter.get
+      } yield c).assertEquals(5)
+    }
   }

   property("head")(forAll((s: Stream[Pure, Int]) => assertEquals(s.head.toList, s.toList.take(1))))
