diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index 415d564f64..ecaccf51a1 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -2000,39 +2000,16 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, Stream.force(fstream) } - /** Interleaves the two inputs nondeterministically. The output stream - * halts after BOTH `s1` and `s2` terminate normally, or in the event - * of an uncaught failure on either `s1` or `s2`. Has the property that - * `merge(Stream.empty, s) == s` and `merge(raiseError(e), s)` will - * eventually terminate with `raiseError(e)`, possibly after emitting some - * elements of `s` first. - * - * The implementation always tries to pull one chunk from each side - * before waiting for it to be consumed by resulting stream. - * As such, there may be up to two chunks (one from each stream) - * waiting to be processed while the resulting stream - * is processing elements. - * - * Also note that if either side produces empty chunk, - * the processing on that side continues, - * w/o downstream requiring to consume result. + /** Implementation of [[merge]], however allows specifying how to combine the output stream. + * This can be used to control how chunks are emitted downstream. See [[mergeAndAwaitDownstream]] for example. * - * If either side does not emit anything (i.e. as result of drain) that side - * will continue to run even when the resulting stream did not ask for more data. 
- * - * Note that even when this is equivalent to `Stream(this, that).parJoinUnbounded`, - * this implementation is little more efficient - * - * @example {{{ - * scala> import scala.concurrent.duration._, cats.effect.IO, cats.effect.unsafe.implicits.global - * scala> val s1 = Stream.awakeEvery[IO](500.millis).scan(0)((acc, _) => acc + 1) - * scala> val s = s1.merge(Stream.sleep_[IO](250.millis) ++ s1) - * scala> s.take(6).compile.toVector.unsafeRunSync() - * res0: Vector[Int] = Vector(0, 0, 1, 1, 2, 2) - * }}} + * @param f The function that combines the output stream and a finalizer for the chunk. + * This way we can control when to pull the next chunk from upstream. */ - def merge[F2[x] >: F[x], O2 >: O]( + private def merge_[F2[x] >: F[x], O2 >: O]( that: Stream[F2, O2] + )( + f: (Stream[F2, O2], F2[Unit]) => Stream[F2, O2] )(implicit F: Concurrent[F2]): Stream[F2, O2] = Stream.force { // `State` describes the state of an upstream stream (`this` and `that` are both upstream streams) @@ -2063,12 +2040,10 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, case (Some(r1), Some(r2)) => CompositeFailure.fromResults(r1, r2) } def run(s: Stream[F2, O2]): F2[Unit] = - // `guard` ensures we do not pull another chunk until the previous one has been consumed downstream. + // `guard` ensures we do not pull another chunk until the previous one has been produced for downstream. Semaphore[F2](1).flatMap { guard => - def sendChunk(chk: Chunk[O2]): F2[Unit] = { - val outStr = Stream.chunk(chk).onFinalize(guard.release) - output.send(outStr) >> guard.acquire - } + def sendChunk(chk: Chunk[O2]): F2[Unit] = + output.send(f(Stream.chunk(chk), guard.release)) >> guard.acquire (Stream.exec(guard.acquire) ++ s.chunks.foreach(sendChunk)) // Stop when the other upstream has errored or the downstream has completed. 
@@ -2103,6 +2078,65 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, } } + /** Like [[merge]], but ensures that each chunk is fully consumed downstream before pulling the next chunk from the same side. + * This loses the equivalence with `Stream(this, that).parJoinUnbounded` but can be useful when we need to never read ahead from + * the merged streams. + * + * @note Pay attention to possible deadlocks of "this" or "that" when using this function, notably in parallel processing + * as unless the chunk is fully processed / scope of the chunk is released, the next chunk will not be pulled. + * + * @example {{{ + * scala> import scala.concurrent.duration._, cats.effect.IO, cats.effect.unsafe.implicits.global + * scala> import cats.effect._ + * scala> Ref.of[IO, Int](0).flatMap{ ref => + * | fs2.Stream.never[IO].mergeAndAwaitDownstream(fs2.Stream.repeatEval(ref.get)).evalMap(value => { + * | IO.sleep(1.second) >> ref.set(value + 1) as value + * | }).take(6).compile.toVector + * | }.unsafeRunSync() + * res0: Vector[Int] = Vector(0, 1, 2, 3, 4, 5) + * }}} + */ + def mergeAndAwaitDownstream[F2[x] >: F[x], O2 >: O]( + that: Stream[F2, O2] + )(implicit F: Concurrent[F2]): Stream[F2, O2] = + merge_(that) { case (s, fin) => s.onFinalize(fin) } + + /** Interleaves the two inputs nondeterministically. The output stream + * halts after BOTH `s1` and `s2` terminate normally, or in the event + * of an uncaught failure on either `s1` or `s2`. Has the property that + * `merge(Stream.empty, s) == s` and `merge(raiseError(e), s)` will + * eventually terminate with `raiseError(e)`, possibly after emitting some + * elements of `s` first. + * + * The implementation always tries to pull one chunk from each side + * before waiting for it to be consumed by resulting stream. + * As such, there may be up to two chunks (one from each stream) + * waiting to be processed while the resulting stream + * is processing elements. 
+ * + * Also note that if either side produces empty chunk, + * the processing on that side continues, + * w/o downstream requiring to consume result. + * + * If either side does not emit anything (i.e. as result of drain) that side + * will continue to run even when the resulting stream did not ask for more data. + * + * Note that even when this is equivalent to `Stream(this, that).parJoinUnbounded`, + * this implementation is little more efficient + * + * @example {{{ + * scala> import scala.concurrent.duration._, cats.effect.IO, cats.effect.unsafe.implicits.global + * scala> val s1 = Stream.awakeEvery[IO](500.millis).scan(0)((acc, _) => acc + 1) + * scala> val s = s1.merge(Stream.sleep_[IO](250.millis) ++ s1) + * scala> s.take(6).compile.toVector.unsafeRunSync() + * res0: Vector[Int] = Vector(0, 0, 1, 1, 2, 2) + * }}} + */ + def merge[F2[x] >: F[x], O2 >: O]( + that: Stream[F2, O2] + )(implicit F: Concurrent[F2]): Stream[F2, O2] = + merge_(that) { case (s, fin) => Stream.exec(fin) ++ s } + /** Like `merge`, but halts as soon as _either_ branch halts. 
*/ def mergeHaltBoth[F2[x] >: F[x]: Concurrent, O2 >: O]( that: Stream[F2, O2] diff --git a/core/shared/src/test/scala/fs2/StreamMergeSuite.scala b/core/shared/src/test/scala/fs2/StreamMergeSuite.scala index a81781533a..32af3b77a3 100644 --- a/core/shared/src/test/scala/fs2/StreamMergeSuite.scala +++ b/core/shared/src/test/scala/fs2/StreamMergeSuite.scala @@ -21,10 +21,10 @@ package fs2 -import scala.concurrent.duration._ - +import scala.concurrent.duration.* import cats.effect.IO import cats.effect.kernel.{Deferred, Ref} +import cats.effect.testkit.TestControl import org.scalacheck.effect.PropF.forAllF class StreamMergeSuite extends Fs2Suite { @@ -224,7 +224,7 @@ class StreamMergeSuite extends Fs2Suite { } } - test("merge not emit ahead") { + test("merge not emit ahead more than 1 chunk") { forAllF { (v: Int) => Ref .of[IO, Int](v) @@ -236,9 +236,100 @@ class StreamMergeSuite extends Fs2Suite { .repeatEval(ref.get) .merge(Stream.never[IO]) .evalMap(sleepAndSet) - .take(2) - .assertEmits(List(v, v + 1)) + .take(6) + .assertEmits(List(v, v, v + 1, v + 1, v + 2, v + 2)) + } + } + } + + test("mergeAndAwaitDownstream not emit ahead") { + forAllF { (v: Int) => + Ref + .of[IO, Int](v) + .flatMap { ref => + def sleepAndSet(value: Int): IO[Int] = + IO.sleep(100.milliseconds) >> ref.set(value + 1) >> IO(value) + + Stream + .repeatEval(ref.get) + .mergeAndAwaitDownstream(Stream.never[IO]) + .evalMap(sleepAndSet) + .take(3) + .assertEmits(List(v, v + 1, v + 2)) } } } + + test("merge produces when concurrently handled") { + + // Create a stream for each int that comes in, + // then run them in parallel, + // where each stream emits the int value and then waits (simulating work that never ends, or ends after a long time). 
+ def run(source: Stream[IO, Int]): IO[Vector[Int]] = + source + .map { a => + Stream.emit(a) ++ + Stream.never[IO] + } + .parJoinUnbounded + .timeoutOnPullTo(200.millis, Stream.empty) + .compile + .toVector + + TestControl + .executeEmbed( + run( + (Stream.emit(1) ++ Stream.sleep_[IO](50.millis) ++ Stream.emit(2)).merge( + Stream.never[IO] + ) + ) + ) + .assertEquals(Vector(1, 2)) + } + + test("issue #3598") { + + sealed trait Data + + case class Item(value: Int) extends Data + case object Tick1 extends Data + case object Tick2 extends Data + + def splitHead[F[_], O](in: fs2.Stream[F, O]): fs2.Stream[F, (O, fs2.Stream[F, O])] = + in.pull.uncons1.flatMap { + case Some((head, tail)) => fs2.Pull.output(Chunk((head, tail))) + case None => fs2.Pull.done + }.stream + + val source = + Stream.emits(1 to 2).evalMap(i => IO(Item(i)).delayBy(100.millis)) ++ Stream.never[IO] + + val timer = fs2.Stream.awakeEvery[IO](50.millis).map(_ => Tick1) + val timer2 = fs2.Stream.awakeEvery[IO](50.millis).map(_ => Tick2) + + val sources = timer2.mergeHaltBoth(source.mergeHaltBoth(timer)) + + val program = + splitHead(sources) + .flatMap { case (head, tail) => + splitHead(tail) + .flatMap { case (head2, tail) => + Stream.emit(head) ++ Stream.emit(head2) ++ tail + } + .parEvalMap(3) { i => + IO(i) + } + } + .interruptAfter(230.millis) + .compile + .toVector + + TestControl + .executeEmbed(program) + .assert { data => + data.count(_.isInstanceOf[Item]) == 2 && + data.count(_.isInstanceOf[Tick1.type]) == 4 && + data.count(_.isInstanceOf[Tick2.type]) == 4 + } + } } diff --git a/core/shared/src/test/scala/fs2/TimedPullsSuite.scala b/core/shared/src/test/scala/fs2/TimedPullsSuite.scala index 8d52d18716..68f92a4387 100644 --- a/core/shared/src/test/scala/fs2/TimedPullsSuite.scala +++ b/core/shared/src/test/scala/fs2/TimedPullsSuite.scala @@ -313,8 +313,19 @@ class TimedPullsSuite extends Fs2Suite { } test("After the first uncons, timeouts start immediately") { + // Time how often we generate 
data in the main stream. + // This is only started after the first uncons. val emissionTime = 100.millis - val timeout = 200.millis + + // Timeout that is registered immediately, before the first uncons, + // but we do not expect it to trigger. + // This has to be longer than emissionTime, otherwise the first uncons would always time out. + val initialTimeout = 200.millis + + // Timeout registered after the first uncons; this one should fire. + val timeout = 50.millis + + // Time we wait before doing uncons. val timedPullPause = Pull.eval(IO.sleep(150.millis)) val prog = @@ -323,7 +334,7 @@ class TimedPullsSuite extends Fs2Suite { .repeatN(2) .pull .timed { tp => - tp.timeout(timeout) >> + tp.timeout(initialTimeout) >> // If the first timeout started immediately, this pause // before uncons would cause a timeout to be emitted timedPullPause >>