diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala
index 96f09ec383..cd35ba920f 100644
--- a/core/shared/src/main/scala/fs2/Stream.scala
+++ b/core/shared/src/main/scala/fs2/Stream.scala
@@ -1406,103 +1406,64 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
       timeout: FiniteDuration
   )(implicit F: Temporal[F2]): Stream[F2, Chunk[O]] = {
-    case class JunctionBuffer[T](
-        data: Vector[T],
-        endOfSupply: Option[Either[Throwable, Unit]],
-        endOfDemand: Option[Either[Throwable, Unit]]
-    ) {
-      def splitAt(n: Int): (JunctionBuffer[T], JunctionBuffer[T]) =
-        if (this.data.size >= n) {
-          val (head, tail) = this.data.splitAt(n.toInt)
-          (this.copy(tail), this.copy(head))
-        } else {
-          (this.copy(Vector.empty), this)
-        }
-    }
+    val groupSize: Long = chunkSize.toLong
 
-    val outputLong = chunkSize.toLong
-    fs2.Stream.force {
-      for {
-        demand <- Semaphore[F2](outputLong)
-        supply <- Semaphore[F2](0L)
-        buffer <- Ref[F2].of(
-          JunctionBuffer[O](Vector.empty[O], endOfSupply = None, endOfDemand = None)
-        )
-      } yield {
-        /* Buffer: stores items from input to be sent on next output chunk
-         * Demand Semaphore: to avoid adding too many items to buffer
-         * Supply: counts filled positions for next output chunk */
-        def enqueue(t: O): F2[Boolean] =
-          for {
-            _ <- demand.acquire
-            buf <- buffer.modify(buf => (buf.copy(buf.data :+ t), buf))
-            _ <- supply.release
-          } yield buf.endOfDemand.isEmpty
-
-        val dequeueNextOutput: F2[Option[Vector[O]]] = {
-          // Trigger: waits until the supply buffer is full (with acquireN)
-          val waitSupply = supply.acquireN(outputLong).guaranteeCase {
-            case Outcome.Succeeded(_) => supply.releaseN(outputLong)
-            case _                    => F.unit
-          }
+    if (timeout.toNanos == 0) chunkN(chunkSize)
+    else
+      Stream.force {
+        for {
+          supply <- Semaphore[F2](0)
+          supplyEnded <- Ref.of[F2, Boolean](false)
+          buffer <- Queue.bounded[F2, O](chunkSize) // buffering and backpressure
+        } yield {
 
-          val onTimeout: F2[Long] =
-            for {
-              _ <- supply.acquire // waits until there is at least one element in buffer
-              m <- supply.available
-              k = m.min(outputLong - 1)
-              b <- supply.tryAcquireN(k)
-            } yield if (b) k + 1 else 1
-
-          // in JS cancellation doesn't always seem to run, so race conditions should restore state on their own
-          for {
-            acq <- F.race(F.sleep(timeout), waitSupply).flatMap {
-              case Left(_)  => onTimeout
-              case Right(_) => supply.acquireN(outputLong).as(outputLong)
-            }
-            buf <- buffer.modify(_.splitAt(acq.toInt))
-            _ <- demand.releaseN(buf.data.size.toLong)
-            res <- buf.endOfSupply match {
-              case Some(Left(error))                  => F.raiseError(error)
-              case Some(Right(_)) if buf.data.isEmpty => F.pure(None)
-              case _                                  => F.pure(Some(buf.data))
+          val emitChunk: F2[Chunk[O]] =
+            buffer.tryTakeN(Some(groupSize.toInt)).map(Chunk.seq)
+
+          // we need to check the buffer size rather than the available supply, since the
+          // supply is increased at the end, so it won't always report the buffer size accurately
+          val isBufferEmpty: F2[Boolean] =
+            buffer.size.map(_ == 0)
+
+          val streamExhausted: F2[Boolean] =
+            (isBufferEmpty, supplyEnded.get).mapN(_ && _)
+
+          // releasing a number of permits equal to {groupSize} is enough in most cases, but in
+          // order to ensure prompt termination of the consumer on interruption, even when the
+          // timeout has not kicked in yet and we have not seen enough elements, we need to max
+          // out the supply
+          val maxOutSupply: F2[Unit] =
+            supply.available.flatMap(av => supply.releaseN(Long.MaxValue - av))
+
+          // enables termination of the consumer stream when the producer completes naturally
+          // (i.e. runs out of elements) or when the combined stream (consumer + producer) is interrupted
+          val endSupply: F2[Unit] = supplyEnded.set(true) *> maxOutSupply
+
+          val enqueue: F2[Unit] =
+            foreach(buffer.offer(_) <* supply.release).compile.drain.guarantee(endSupply)
+
+          val awaitAndEmitNext: F2[Chunk[O]] = for {
+            isEmpty <- isBufferEmpty
+            awaited <- supply.acquire.whenA(isEmpty).as(if (isEmpty) 1 else 0)
+            flushed <- emitChunk
+            // lower the supply by {flushed.size} (excluding the element already awaited)
+            _ <- supply.acquireN((flushed.size.toLong - awaited).max(0))
+          } yield flushed
+
+          val onTimeout: F2[Chunk[O]] =
+            F.ifM(streamExhausted)(F.pure(Chunk.empty[O]), awaitAndEmitNext)
+
+          val dequeue: F2[Chunk[O]] =
+            F.race(supply.acquireN(groupSize), F.sleep(timeout)).flatMap {
+              case Left(_)  => emitChunk
+              case Right(_) => onTimeout
             }
-          } yield res
-        }
-
-        def endSupply(result: Either[Throwable, Unit]): F2[Unit] =
-          buffer.update(_.copy(endOfSupply = Some(result))) *> supply.releaseN(Int.MaxValue)
-
-        def endDemand(result: Either[Throwable, Unit]): F2[Unit] =
-          buffer.update(_.copy(endOfDemand = Some(result))) *> demand.releaseN(Int.MaxValue)
 
-        def toEnding(ec: ExitCase): Either[Throwable, Unit] = ec match {
-          case ExitCase.Succeeded  => Right(())
-          case ExitCase.Errored(e) => Left(e)
-          case ExitCase.Canceled   => Right(())
-        }
-
-        val enqueueAsync = F.start {
-          this
-            .evalMap(enqueue)
-            .forall(identity)
-            .onFinalizeCase(ec => endSupply(toEnding(ec)))
-            .compile
-            .drain
-        }
-
-        val outputStream: Stream[F2, Chunk[O]] = Stream
-          .eval(dequeueNextOutput)
-          .repeat
-          .collectWhile { case Some(data) => Chunk.vector(data) }
-
-        Stream
-          .bracketCase(enqueueAsync) { case (upstream, exitCase) =>
-            endDemand(toEnding(exitCase)) *> upstream.cancel
-          } >> outputStream
+          Stream
+            .repeatEval(dequeue)
+            .collectWhile { case os if os.nonEmpty => os }
+            .concurrently(Stream.eval(enqueue))
+        }
       }
-    }
   }
 
   /** If `this` terminates with `Stream.raiseError(e)`, invoke `h(e)`.
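Reviewer note: the rewrite replaces the hand-rolled `JunctionBuffer` with a `cats.effect.std.Queue` for buffering/backpressure plus a single `supply` semaphore that the consumer races against `F.sleep(timeout)`. Below is a minimal, self-contained usage sketch of the two emission triggers (not part of the diff; the object name and timings are illustrative, assuming fs2 3.x and cats-effect 3):

```scala
import cats.effect.{IO, IOApp}
import scala.concurrent.duration._

// Illustrative demo, not from the PR: elements arrive every 100ms, so the
// size trigger (5 elements) fires before the 1s timeout and full chunks of
// 5 are emitted; slow the producer below 1 element/s to see the timeout
// path flush partial chunks instead.
object GroupWithinDemo extends IOApp.Simple {
  val run: IO[Unit] =
    fs2.Stream
      .awakeEvery[IO](100.millis)
      .groupWithin(5, 1.second)
      .take(4)
      .evalMap(chunk => IO.println(s"emitted a chunk of ${chunk.size}"))
      .compile
      .drain
}
```

In terms of the new implementation, the size trigger corresponds to `supply.acquireN(groupSize)` winning the race in `dequeue`, and the timeout trigger to `onTimeout` flushing a partial chunk via `awaitAndEmitNext`.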
diff --git a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
index 8e703421bb..e60168647b 100644
--- a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
+++ b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
@@ -34,6 +34,7 @@ import org.scalacheck.Prop.forAll
 
 import scala.concurrent.duration._
 import scala.concurrent.TimeoutException
+import scala.util.control.NoStackTrace
 
 class StreamCombinatorsSuite extends Fs2Suite {
 
@@ -831,6 +832,59 @@ class StreamCombinatorsSuite extends Fs2Suite {
         )
         .assertEquals(0.millis)
     }
+
+    test("upstream failures are propagated downstream") {
+
+      case object SevenNotAllowed extends NoStackTrace
+
+      val source = Stream
+        .unfold(0)(s => Some((s, s + 1)))
+        .covary[IO]
+        .evalMap(n => if (n == 7) IO.raiseError(SevenNotAllowed) else IO.pure(n))
+
+      val downstream = source.groupWithin(100, 2.seconds)
+
+      downstream.compile.lastOrError.intercept[SevenNotAllowed.type]
+    }
+
+    test(
+      "upstream interruption causes immediate downstream termination with all elements being emitted"
+    ) {
+
+      val sourceTimeout = 5.5.seconds
+      val downstreamTimeout = sourceTimeout + 2.seconds
+
+      TestControl
+        .executeEmbed(
+          Ref[IO]
+            .of(0.millis)
+            .flatMap { ref =>
+              val source: Stream[IO, Int] =
+                Stream
+                  .unfold(0)(s => Some((s, s + 1)))
+                  .covary[IO]
+                  .meteredStartImmediately(1.second)
+                  .interruptAfter(sourceTimeout)
+
+              // large chunkSize and timeout: no emissions are expected within the
+              // specified window unless the source ends, whether through interruption
+              // or through natural termination (i.e. it runs out of elements)
+              val downstream: Stream[IO, Chunk[Int]] =
+                source.groupWithin(Int.MaxValue, 1.day)
+
+              downstream.compile.lastOrError
+                .map(_.toList)
+                .timeout(downstreamTimeout)
+                .flatTap(_ => IO.monotonic.flatMap(ref.set))
+                .flatMap(emit => ref.get.map(timeLapsed => (timeLapsed, emit)))
+            }
+        )
+        .assertEquals(
+          // downstream ended immediately (i.e. timeLapsed == sourceTimeout),
+          // emitting whatever was accumulated at the time of interruption
+          (sourceTimeout, List(0, 1, 2, 3, 4, 5))
+        )
+    }
   }
 
   property("head")(forAll((s: Stream[Pure, Int]) => assertEquals(s.head.toList, s.toList.take(1))))
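Reviewer note: the duration-based assertion in the interruption test relies on cats-effect's test kit. A minimal sketch of the `TestControl` pattern used above (standalone and illustrative, assuming `cats-effect-testkit` on the test classpath):

```scala
import cats.effect.IO
import cats.effect.testkit.TestControl
import scala.concurrent.duration._

// executeEmbed runs the program on a mocked clock: sleeps complete
// instantly in real time, while IO.monotonic advances deterministically.
val program: IO[FiniteDuration] =
  IO.sleep(1.day) *> IO.monotonic

// finishes immediately, yet observes exactly 1 day of mock-clock time
val checked: IO[Unit] =
  TestControl.executeEmbed(program).map(elapsed => assert(elapsed == 1.day))
```

This determinism is what lets the test assert the exact tuple `(sourceTimeout, List(0, 1, 2, 3, 4, 5))`: mock time only advances through sleeps, so `timeLapsed` equals `sourceTimeout` precisely when the downstream terminates as soon as the source is interrupted.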