Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
package cats.effect.benchmarks

import cats.effect.IO
import cats.effect.syntax.all._
import cats.effect.unsafe.implicits.global
import cats.implicits.{catsSyntaxParallelTraverse1, toTraverseOps}

import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra.Blackhole

import scala.concurrent.duration._

import java.util.concurrent.TimeUnit

/**
Expand Down Expand Up @@ -55,6 +58,24 @@ class ParallelBenchmark {
def parTraverse(): Unit =
1.to(size).toList.parTraverse(_ => IO(Blackhole.consumeCPU(cpuTokens))).void.unsafeRunSync()

@Benchmark
def parTraverseN(): Unit =
1.to(size)
.toList
.parTraverseN(size / 100)(_ => IO(Blackhole.consumeCPU(cpuTokens)))
.void
.unsafeRunSync()

@Benchmark
def parTraverseNCancel(): Unit = {
val e = new RuntimeException
val test = 1.to(size * 100).toList.parTraverseN(size / 100) { _ =>
IO.sleep(100.millis) *> IO.raiseError(e)
}

test.attempt.void.unsafeRunSync()
}

@Benchmark
def traverse(): Unit =
1.to(size).toList.traverse(_ => IO(Blackhole.consumeCPU(cpuTokens))).void.unsafeRunSync()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,52 @@ trait GenConcurrent[F[_], E] extends GenSpawn[F, E] {

implicit val F: GenConcurrent[F, E] = this

MiniSemaphore[F](n).flatMap { sem => ta.parTraverse { a => sem.withPermit(f(a)) } }
F.deferred[Option[E]] flatMap { preempt =>
F.ref[Set[Fiber[F, ?, ?]]](Set()) flatMap { supervision =>
MiniSemaphore[F](n) flatMap { sem =>
val results = ta traverse { a =>
preempt.tryGet flatMap {
case Some(_) =>
// it's okay to produce never here because the early abort preceeds us
// this effect won't get sequenced, so it can be anything really
F.pure(F.never[B])

case None =>
F.uncancelable { poll =>
F.deferred[Outcome[F, E, B]] flatMap { result =>
val action = poll(sem.acquire) >> f(a)
.guaranteeCase { oc =>
result.complete(oc) *> oc.fold(
preempt.complete(None).void,
e => preempt.complete(Some(e)).void,
_ => F.unit) *> sem.release
}
.void
.voidError
.start

action flatMap { fiber =>
supervision.update(_ + fiber) map { _ =>
result
.get
.flatMap(_.embed(F.canceled *> F.never))
.onCancel(fiber.cancel)
.guarantee(supervision.update(_ - fiber))
}
}
}
}
}
}

results.flatMap(_.sequence) guaranteeCase {
case Outcome.Succeeded(_) => F.unit
// has to be done in parallel to avoid head of line issues
case _ => supervision.get.flatMap(_.toList.parTraverse_(_.cancel))
}
}
}
}
}

/**
Expand All @@ -152,7 +197,52 @@ trait GenConcurrent[F[_], E] extends GenSpawn[F, E] {

implicit val F: GenConcurrent[F, E] = this

MiniSemaphore[F](n).flatMap { sem => ta.parTraverse_ { a => sem.withPermit(f(a)) } }
// TODO we need to write a test for error cancelation
F.deferred[Option[E]] flatMap { preempt =>
F.ref[List[Fiber[F, ?, ?]]](Nil) flatMap { supervision =>
MiniSemaphore[F](n) flatMap { sem =>
val startAll = ta traverse_ { a =>
// first check to see if any of the effects have errored out
// don't bother starting new things if that happens
preempt.tryGet flatMap {
case Some(_) =>
F.unit // allow the error to be resurfaced later

case None =>
F.uncancelable { poll =>
// if the effect produces an error, race to kill all the rest
val wrapped = f(a) guaranteeCase { oc =>
sem.release *> oc.fold(
preempt.complete(None).void,
e => preempt.complete(Some(e)).void,
_ => F.unit)
}

val suppressed = wrapped.void.voidError

poll(sem.acquire) >> suppressed.start flatMap { fiber =>
// supervision is handled very differently here: we never remove from the set
supervision.update(fiber :: _)
}
}
}
}

val cancelAll = supervision.get.flatMap(_.parTraverse_(_.cancel))

startAll.onCancel(cancelAll) *>
// we block until it's all done by acquiring all the permits
F.race(preempt.get *> cancelAll, sem.acquire.replicateA_(n)) *>
// if we hit an error or self-cancelation in any effect, resurface it here
// note that we can't lose errors here because of the permits: we know the fibers are done
preempt.tryGet flatMap {
case Some(Some(e)) => F.raiseError(e)
case Some(None) => F.canceled
case None => F.unit
}
}
}
}
}

override def racePair[A, B](fa: F[A], fb: F[B])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@ import scala.collection.immutable.{Queue => ScalaQueue}
* A cut-down version of semaphore used to implement parTraverseN
*/
private[kernel] abstract class MiniSemaphore[F[_]] extends Serializable {
def acquire: F[Unit]
def release: F[Unit]

/**
* Sequence an action while holding a permit
*/
def withPermit[A](fa: F[A]): F[A]
}

Expand Down
140 changes: 140 additions & 0 deletions tests/shared/src/test/scala/cats/effect/IOSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,146 @@ class IOSpec extends BaseSpec with Discipline with IOPlatformSpecification {
p must completeAs(true)
}

"run finalizers when canceled" in ticked { implicit ticker =>
val p = for {
r <- IO.ref(0)

/*
* The exact series of steps here is:
*
* List(IO.never.onCancel, IO.unit, IO.never.onCancel)
*
* This is significant because we're limiting the parallelism to
* 2, meaning that we will hit a wall after IO.unit. HOWEVER,
* IO.unit completes immediately, so this test not only checks
* cancelation, it also tests that we move onto the third item
* after the second one completes even while the first is blocked.
* In other words, it's testing both cancelation and head of line
* behavior.
*/
f <- List(1, 2, 3)
.parTraverseN(2) { i =>
if (i == 2) IO.unit
else IO.never.onCancel(r.update(_ + 1))
}
.start

_ <- IO.sleep(100.millis)
_ <- f.cancel
c <- r.get
_ <- IO { c mustEqual 2 }
} yield true

p must completeAs(true)
}

"propagate self-cancellation" in ticked { implicit ticker =>
List(1, 2, 3, 4)
.parTraverseN(2) { (n: Int) =>
if (n == 3) IO.canceled *> IO.never
else IO.pure(n)
}
.void must selfCancel
}

"run finalizers when a task self-cancels" in ticked { implicit ticker =>
val p = for {
r <- IO.ref(0)
fib <- List(1, 2, 3, 4)
.parTraverseN(2) { (n: Int) =>
if (n == 3) IO.canceled *> IO.never
else IO.pure(n)
}
.onCancel(r.update(_ + 1))
.void
.start
_ <- IO.sleep(100.millis)
c <- r.get
_ <- IO { c mustEqual 1 }
oc <- fib.join
} yield oc.isCanceled

p must completeAs(true)
}

"not run more than `n` tasks at a time" in real {
def task(counter: Ref[IO, Int], maximum: Ref[IO, Int]): IO[Unit] = {
val acq = counter.updateAndGet(_ + 1).flatMap { count =>
maximum.update { max => if (count > max) count else max }
}
IO.asyncForIO.bracket(acq) { _ => IO.sleep(100.millis) }(_ => counter.update(_ - 1))
}

for {
maximum <- Ref.of[IO, Int](0)
counter <- Ref.of[IO, Int](0)
nCpu <- IO { Runtime.getRuntime().availableProcessors() }
n = java.lang.Math.max(nCpu, 2)
size = 4 * n
res <- (1 to size).toList.parTraverseN(n) { _ => task(counter, maximum) }
_ <- IO { res.size mustEqual size }
count <- counter.get
_ <- IO { count mustEqual 0 }
max <- maximum.get
_ <- IO { max must beLessThanOrEqualTo(n) }
} yield ok
}

"run actually in parallel" in real {
val n = 4
(1 to 2 * n)
.toList
.map { i => IO.sleep(1.second).as(i) }
.parSequenceN(n)
.timeout(3.seconds)
.flatMap { res => IO { res mustEqual (1 to 2 * n).toList } }
}

"work for empty traverse" in ticked { implicit ticker =>
List.empty[Int].parTraverseN(4) { _ => IO.never[String] } must completeAs(
List.empty[String])
}

"work for non-empty traverse (ticked)" in ticked { implicit ticker =>
List(1).parTraverseN(4) { i => IO.pure(i.toString) } must completeAs(List("1"))
List(1, 2).parTraverseN(3) { i => IO.pure(i.toString) } must completeAs(List("1", "2"))
List(1, 2, 3).parTraverseN(2) { i => IO.pure(i.toString) } must completeAs(
List("1", "2", "3"))
List(1, 2, 3, 4).parTraverseN(1) { i => IO.pure(i.toString) } must completeAs(
List("1", "2", "3", "4"))
}

"work for non-empty traverse (real)" in real {
for {
_ <- List(1).parTraverseN(4)(i => IO.pure(i.toString)).flatMap { r =>
IO(r mustEqual List("1"))
}
_ <- List(1, 2).parTraverseN(3)(i => IO.pure(i.toString)).flatMap { r =>
IO(r mustEqual List("1", "2"))
}
_ <- List(1, 2, 3).parTraverseN(2)(i => IO.pure(i.toString)).flatMap { r =>
IO(r mustEqual List("1", "2", "3"))
}
_ <- List(1, 2, 3, 4).parTraverseN(1)(i => IO.pure(i.toString)).flatMap { r =>
IO(r mustEqual List("1", "2", "3", "4"))
}
_ <- (1 to 10000).toList.parTraverseN(2)(i => IO.pure(i.toString)).flatMap { r =>
IO(r mustEqual (1 to 10000).map(_.toString).toList)
}
} yield ok
}

"be null-safe" in real {
for {
r1 <- List[String]("a", "b", null, "d", null).parTraverseN(2) {
case "a" => IO.pure(null)
case "b" => IO.pure("x")
case "d" => IO.pure(null)
case null => IO.pure("z")
}
_ <- IO { r1 mustEqual List(null, "x", "z", null, "z") }
} yield ok
}
}

"parTraverseN_" should {
Expand Down
Loading