Skip to content

Commit 584ce3b

Browse files
committed
Added more tests and shuffled preemption paths to actually, you know, work
1 parent 3649a5d commit 584ce3b

File tree

2 files changed

+277
-90
lines changed

2 files changed

+277
-90
lines changed

kernel/shared/src/main/scala/cats/effect/kernel/GenConcurrent.scala

Lines changed: 107 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -140,81 +140,82 @@ trait GenConcurrent[F[_], E] extends GenSpawn[F, E] {
140140
implicit val F: GenConcurrent[F, E] = this
141141

142142
F.deferred[Option[E]] flatMap { preempt =>
143-
F.ref[Set[Fiber[F, ?, ?]]](Set()) flatMap { supervision =>
144-
// has to be done in parallel to avoid head of line issues
145-
val cancelAllF = supervision.get.flatMap(_.toList.parTraverse_(_.cancel))
143+
F.ref[Set[(Fiber[F, ?, ?], Deferred[F, Outcome[F, E, B]])]](Set()) flatMap {
144+
supervision =>
145+
// has to be done in parallel to avoid head of line issues
146+
def cancelAll(cause: Option[E]) = supervision.get flatMap { states =>
147+
val causeOC: Outcome[F, E, B] = cause match {
148+
case Some(e) => Outcome.Errored(e)
149+
case None => Outcome.Canceled()
150+
}
146151

147-
MiniSemaphore[F](n) flatMap { sem =>
148-
val results = ta traverse { a =>
149-
preempt.tryGet flatMap {
150-
case Some(_) =>
151-
// it's okay to produce never here because the early abort preceeds us
152-
// this effect won't get sequenced, so it can be anything really
153-
F.pure(F.never[B])
152+
states.toList parTraverse_ {
153+
case (fiber, result) =>
154+
result.complete(causeOC) *> fiber.cancel
155+
}
156+
}
154157

155-
case None =>
156-
F.uncancelable { poll =>
157-
F.deferred[Outcome[F, E, B]] flatMap { result =>
158-
val action = poll(sem.acquire) >> f(a)
159-
.guaranteeCase { oc =>
160-
val completion = oc match {
161-
case Outcome.Succeeded(_) =>
162-
preempt.tryGet flatMap {
163-
case Some(Some(e)) =>
164-
result.complete(Outcome.Errored(e))
165-
166-
case Some(None) =>
167-
result.complete(Outcome.Canceled())
168-
169-
case None =>
170-
result.complete(oc)
171-
}
172-
173-
case Outcome.Errored(e) =>
174-
preempt.complete(Some(e)) flatMap { won =>
175-
if (won)
176-
result.complete(oc) <* cancelAllF.start // avoid deadlock
177-
else
178-
preempt.get flatMap {
179-
case Some(e) => result.complete(Outcome.Errored(e))
180-
case None => result.complete(Outcome.Canceled())
181-
}
182-
}
183-
184-
case Outcome.Canceled() =>
185-
preempt.complete(None) flatMap { won =>
186-
if (won)
187-
result.complete(oc) <* cancelAllF.start // avoid deadlock
188-
else
189-
preempt.get flatMap {
190-
case Some(e) => result.complete(Outcome.Errored(e))
191-
case None => result.complete(Outcome.Canceled())
192-
}
193-
}
158+
MiniSemaphore[F](n) flatMap { sem =>
159+
val results = ta traverse { a =>
160+
preempt.tryGet flatMap {
161+
case Some(Some(e)) => F.pure(F.raiseError[B](e))
162+
case Some(None) => F.pure(F.canceled *> F.never[B])
163+
164+
case None =>
165+
F.uncancelable { poll =>
166+
F.deferred[Outcome[F, E, B]] flatMap { result =>
167+
val action = poll(sem.acquire) >> f(a)
168+
.guaranteeCase { oc =>
169+
val completion = oc match {
170+
case Outcome.Succeeded(_) =>
171+
preempt.tryGet flatMap {
172+
case Some(Some(e)) =>
173+
result.complete(Outcome.Errored(e))
174+
175+
case Some(None) =>
176+
result.complete(Outcome.Canceled())
177+
178+
case None =>
179+
result.complete(oc)
180+
}
181+
182+
case Outcome.Errored(e) =>
183+
preempt
184+
.complete(Some(e))
185+
.ifM(
186+
result.complete(oc) <* cancelAll(Some(e)).start,
187+
false.pure[F])
188+
189+
case Outcome.Canceled() =>
190+
preempt
191+
.complete(None)
192+
.ifM(
193+
result.complete(oc) <* cancelAll(None).start,
194+
false.pure[F])
195+
}
196+
197+
completion *> sem.release
198+
}
199+
.void
200+
.voidError
201+
.start
202+
203+
action flatMap { fiber =>
204+
supervision.update(_ + ((fiber, result))) map { _ =>
205+
result
206+
.get
207+
.flatMap(_.embed(F.canceled *> F.never))
208+
.onCancel(fiber.cancel)
209+
.guarantee(supervision.update(_ - ((fiber, result))))
194210
}
195-
196-
completion *> sem.release
197-
}
198-
.void
199-
.voidError
200-
.start
201-
202-
action flatMap { fiber =>
203-
supervision.update(_ + fiber) map { _ =>
204-
result
205-
.get
206-
.flatMap(_.embed(F.canceled *> F.never))
207-
.onCancel(fiber.cancel)
208-
.guarantee(supervision.update(_ - fiber))
209211
}
210212
}
211213
}
212-
}
214+
}
213215
}
214-
}
215216

216-
results.flatMap(_.sequence).onCancel(cancelAllF)
217-
}
217+
results.flatMap(_.sequence).onCancel(cancelAll(None))
218+
}
218219
}
219220
}
220221
}
@@ -229,49 +230,66 @@ trait GenConcurrent[F[_], E] extends GenSpawn[F, E] {
229230

230231
implicit val F: GenConcurrent[F, E] = this
231232

232-
// TODO we need to write a test for error cancelation
233233
F.deferred[Option[E]] flatMap { preempt =>
234234
F.ref[List[Fiber[F, ?, ?]]](Nil) flatMap { supervision =>
235235
MiniSemaphore[F](n) flatMap { sem =>
236+
val cancelAll = supervision.get.flatMap(_.parTraverse_(_.cancel))
237+
238+
// doesn't complete until every fiber has been at least *started*
236239
val startAll = ta traverse_ { a =>
237240
// first check to see if any of the effects have errored out
238241
// don't bother starting new things if that happens
239242
preempt.tryGet flatMap {
240-
case Some(_) =>
241-
F.unit // allow the error to be resurfaced later
243+
case Some(Some(e)) =>
244+
F.raiseError[Unit](e)
245+
246+
case Some(None) =>
247+
F.canceled
242248

243249
case None =>
244250
F.uncancelable { poll =>
245-
// if the effect produces an error, race to kill all the rest
246-
val wrapped = f(a) guaranteeCase { oc =>
247-
sem.release *> oc.fold(
248-
preempt.complete(None).void,
249-
e => preempt.complete(Some(e)).void,
250-
_ => F.unit)
251+
// if the effect produces a non-success, race to kill all the rest
252+
val wrapped = f(a) guaranteeCase {
253+
case Outcome.Succeeded(_) =>
254+
F.unit
255+
256+
case Outcome.Errored(e) =>
257+
preempt.complete(Some(e)).void
258+
259+
case Outcome.Canceled() =>
260+
preempt.complete(None).void
251261
}
252262

253-
val suppressed = wrapped.void.voidError
263+
val suppressed = wrapped.void.voidError.guarantee(sem.release)
254264

255-
poll(sem.acquire) >> suppressed.start flatMap { fiber =>
265+
poll(sem.acquire) *> suppressed.start flatMap { fiber =>
256266
// supervision is handled very differently here: we never remove from the set
257267
supervision.update(fiber :: _)
258268
}
259269
}
260270
}
261271
}
262272

263-
val cancelAll = supervision.get.flatMap(_.parTraverse_(_.cancel))
273+
// we only run this when we know that supervision is full
274+
val awaitAll = preempt.tryGet flatMap {
275+
case Some(_) => F.unit
276+
case None =>
277+
F.race(preempt.get.void, supervision.get.flatMap(_.traverse_(_.join.void))).void
278+
}
264279

265-
startAll.onCancel(cancelAll) *>
266-
// we block until it's all done by acquiring all the permits
267-
F.race(preempt.get *> cancelAll, sem.acquire.replicateA_(n)) *>
268-
// if we hit an error or self-cancelation in any effect, resurface it here
269-
// note that we can't lose errors here because of the permits: we know the fibers are done
270-
preempt.tryGet flatMap {
271-
case Some(Some(e)) => F.raiseError(e)
272-
case Some(None) => F.canceled
273-
case None => F.unit
274-
}
280+
// if we hit an error or self-cancelation in any effect, resurface it here
281+
val resurface = preempt.tryGet flatMap {
282+
case Some(Some(e)) => F.raiseError[Unit](e)
283+
case Some(None) => F.canceled
284+
case None => F.unit
285+
}
286+
287+
val work = (startAll *> awaitAll) guaranteeCase {
288+
case Outcome.Succeeded(_) => F.unit
289+
case Outcome.Errored(_) | Outcome.Canceled() => preempt.complete(None) *> cancelAll
290+
}
291+
292+
work *> resurface
275293
}
276294
}
277295
}

0 commit comments

Comments
 (0)