@@ -140,81 +140,82 @@ trait GenConcurrent[F[_], E] extends GenSpawn[F, E] {
140140 implicit val F : GenConcurrent [F , E ] = this
141141
142142 F .deferred[Option [E ]] flatMap { preempt =>
143- F .ref[Set [Fiber [F , ? , ? ]]](Set ()) flatMap { supervision =>
144- // has to be done in parallel to avoid head of line issues
145- val cancelAllF = supervision.get.flatMap(_.toList.parTraverse_(_.cancel))
143+ F .ref[Set [(Fiber [F , ? , ? ], Deferred [F , Outcome [F , E , B ]])]](Set ()) flatMap {
144+ supervision =>
145+ // has to be done in parallel to avoid head of line issues
146+ def cancelAll (cause : Option [E ]) = supervision.get flatMap { states =>
147+ val causeOC : Outcome [F , E , B ] = cause match {
148+ case Some (e) => Outcome .Errored (e)
149+ case None => Outcome .Canceled ()
150+ }
146151
147- MiniSemaphore [F ](n) flatMap { sem =>
148- val results = ta traverse { a =>
149- preempt.tryGet flatMap {
150- case Some (_) =>
151- // it's okay to produce never here because the early abort preceeds us
152- // this effect won't get sequenced, so it can be anything really
153- F .pure(F .never[B ])
152+ states.toList parTraverse_ {
153+ case (fiber, result) =>
154+ result.complete(causeOC) *> fiber.cancel
155+ }
156+ }
154157
155- case None =>
156- F .uncancelable { poll =>
157- F .deferred[Outcome [F , E , B ]] flatMap { result =>
158- val action = poll(sem.acquire) >> f(a)
159- .guaranteeCase { oc =>
160- val completion = oc match {
161- case Outcome .Succeeded (_) =>
162- preempt.tryGet flatMap {
163- case Some (Some (e)) =>
164- result.complete(Outcome .Errored (e))
165-
166- case Some (None ) =>
167- result.complete(Outcome .Canceled ())
168-
169- case None =>
170- result.complete(oc)
171- }
172-
173- case Outcome .Errored (e) =>
174- preempt.complete(Some (e)) flatMap { won =>
175- if (won)
176- result.complete(oc) <* cancelAllF.start // avoid deadlock
177- else
178- preempt.get flatMap {
179- case Some (e) => result.complete(Outcome .Errored (e))
180- case None => result.complete(Outcome .Canceled ())
181- }
182- }
183-
184- case Outcome .Canceled () =>
185- preempt.complete(None ) flatMap { won =>
186- if (won)
187- result.complete(oc) <* cancelAllF.start // avoid deadlock
188- else
189- preempt.get flatMap {
190- case Some (e) => result.complete(Outcome .Errored (e))
191- case None => result.complete(Outcome .Canceled ())
192- }
193- }
158+ MiniSemaphore [F ](n) flatMap { sem =>
159+ val results = ta traverse { a =>
160+ preempt.tryGet flatMap {
161+ case Some (Some (e)) => F .pure(F .raiseError[B ](e))
162+ case Some (None ) => F .pure(F .canceled *> F .never[B ])
163+
164+ case None =>
165+ F .uncancelable { poll =>
166+ F .deferred[Outcome [F , E , B ]] flatMap { result =>
167+ val action = poll(sem.acquire) >> f(a)
168+ .guaranteeCase { oc =>
169+ val completion = oc match {
170+ case Outcome .Succeeded (_) =>
171+ preempt.tryGet flatMap {
172+ case Some (Some (e)) =>
173+ result.complete(Outcome .Errored (e))
174+
175+ case Some (None ) =>
176+ result.complete(Outcome .Canceled ())
177+
178+ case None =>
179+ result.complete(oc)
180+ }
181+
182+ case Outcome .Errored (e) =>
183+ preempt
184+ .complete(Some (e))
185+ .ifM(
186+ result.complete(oc) <* cancelAll(Some (e)).start,
187+ false .pure[F ])
188+
189+ case Outcome .Canceled () =>
190+ preempt
191+ .complete(None )
192+ .ifM(
193+ result.complete(oc) <* cancelAll(None ).start,
194+ false .pure[F ])
195+ }
196+
197+ completion *> sem.release
198+ }
199+ .void
200+ .voidError
201+ .start
202+
203+ action flatMap { fiber =>
204+ supervision.update(_ + ((fiber, result))) map { _ =>
205+ result
206+ .get
207+ .flatMap(_.embed(F .canceled *> F .never))
208+ .onCancel(fiber.cancel)
209+ .guarantee(supervision.update(_ - ((fiber, result))))
194210 }
195-
196- completion *> sem.release
197- }
198- .void
199- .voidError
200- .start
201-
202- action flatMap { fiber =>
203- supervision.update(_ + fiber) map { _ =>
204- result
205- .get
206- .flatMap(_.embed(F .canceled *> F .never))
207- .onCancel(fiber.cancel)
208- .guarantee(supervision.update(_ - fiber))
209211 }
210212 }
211213 }
212- }
214+ }
213215 }
214- }
215216
216- results.flatMap(_.sequence).onCancel(cancelAllF )
217- }
217+ results.flatMap(_.sequence).onCancel(cancelAll( None ) )
218+ }
218219 }
219220 }
220221 }
@@ -229,49 +230,66 @@ trait GenConcurrent[F[_], E] extends GenSpawn[F, E] {
229230
230231 implicit val F : GenConcurrent [F , E ] = this
231232
232- // TODO we need to write a test for error cancelation
233233 F .deferred[Option [E ]] flatMap { preempt =>
234234 F .ref[List [Fiber [F , ? , ? ]]](Nil ) flatMap { supervision =>
235235 MiniSemaphore [F ](n) flatMap { sem =>
236+ val cancelAll = supervision.get.flatMap(_.parTraverse_(_.cancel))
237+
238+ // doesn't complete until every fiber has been at least *started*
236239 val startAll = ta traverse_ { a =>
237240 // first check to see if any of the effects have errored out
238241 // don't bother starting new things if that happens
239242 preempt.tryGet flatMap {
240- case Some (_) =>
241- F .unit // allow the error to be resurfaced later
243+ case Some (Some (e)) =>
244+ F .raiseError[Unit ](e)
245+
246+ case Some (None ) =>
247+ F .canceled
242248
243249 case None =>
244250 F .uncancelable { poll =>
245- // if the effect produces an error, race to kill all the rest
246- val wrapped = f(a) guaranteeCase { oc =>
247- sem.release *> oc.fold(
248- preempt.complete(None ).void,
249- e => preempt.complete(Some (e)).void,
250- _ => F .unit)
251+ // if the effect produces a non-success, race to kill all the rest
252+ val wrapped = f(a) guaranteeCase {
253+ case Outcome .Succeeded (_) =>
254+ F .unit
255+
256+ case Outcome .Errored (e) =>
257+ preempt.complete(Some (e)).void
258+
259+ case Outcome .Canceled () =>
260+ preempt.complete(None ).void
251261 }
252262
253- val suppressed = wrapped.void.voidError
263+ val suppressed = wrapped.void.voidError.guarantee(sem.release)
254264
255- poll(sem.acquire) > > suppressed.start flatMap { fiber =>
265+ poll(sem.acquire) * > suppressed.start flatMap { fiber =>
256266 // supervision is handled very differently here: we never remove from the set
257267 supervision.update(fiber :: _)
258268 }
259269 }
260270 }
261271 }
262272
263- val cancelAll = supervision.get.flatMap(_.parTraverse_(_.cancel))
273+ // we only run this when we know that supervision is full
274+ val awaitAll = preempt.tryGet flatMap {
275+ case Some (_) => F .unit
276+ case None =>
277+ F .race(preempt.get.void, supervision.get.flatMap(_.traverse_(_.join.void))).void
278+ }
264279
265- startAll.onCancel(cancelAll) *>
266- // we block until it's all done by acquiring all the permits
267- F .race(preempt.get *> cancelAll, sem.acquire.replicateA_(n)) *>
268- // if we hit an error or self-cancelation in any effect, resurface it here
269- // note that we can't lose errors here because of the permits: we know the fibers are done
270- preempt.tryGet flatMap {
271- case Some (Some (e)) => F .raiseError(e)
272- case Some (None ) => F .canceled
273- case None => F .unit
274- }
280+ // if we hit an error or self-cancelation in any effect, resurface it here
281+ val resurface = preempt.tryGet flatMap {
282+ case Some (Some (e)) => F .raiseError[Unit ](e)
283+ case Some (None ) => F .canceled
284+ case None => F .unit
285+ }
286+
287+ val work = (startAll *> awaitAll) guaranteeCase {
288+ case Outcome .Succeeded (_) => F .unit
289+ case Outcome .Errored (_) | Outcome .Canceled () => preempt.complete(None ) *> cancelAll
290+ }
291+
292+ work *> resurface
275293 }
276294 }
277295 }
0 commit comments