Skip to content

Commit 29e4b62

Browse files
Fix race condition between Topic.publish1 and Topic.close
1 parent a30644f commit 29e4b62

File tree

1 file changed

+49
-27
lines changed

1 file changed

+49
-27
lines changed

core/shared/src/main/scala/fs2/concurrent/Topic.scala

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -152,33 +152,45 @@ object Topic {
152152
(
153153
F.ref(State.initial[F, A]),
154154
SignallingRef[F, Int](0),
155+
F.deferred[Unit],
155156
F.deferred[Unit]
156-
).mapN { case (state, subscriberCount, signalClosure) =>
157+
).mapN { case (state, subscriberCount, signalClosure, publishersFinished) =>
157158
new Topic[F, A] {
158159

159160
def foreach[B](lm: LongMap[B])(f: B => F[Unit]) =
160-
lm.foldLeft(F.unit) { case (op, (_, b)) => op >> f(b) }
161+
lm.foldLeft(F.unit) { case (op, (_, b)) => f(b) >> op }
161162

162163
def publish1(a: A): F[Either[Topic.Closed, Unit]] =
163-
state.get.flatMap {
164-
case State.Closed() =>
165-
Topic.closed.pure[F]
166-
case State.Active(subs, _) =>
167-
subs.foldLeft(F.pure(Topic.rightUnit)) { case (acc, (_, chan)) =>
168-
acc.flatMap {
169-
case Left(Topic.Closed) => Topic.closed.pure[F]
170-
case Right(_) =>
171-
chan.send(a).flatMap {
172-
case Right(_) => Topic.rightUnit.pure[F]
173-
case Left(_) =>
174-
// Channel send failed, check if topic was closed
175-
state.get.map {
176-
case State.Closed() => Topic.closed
177-
case State.Active(_, _) => Topic.rightUnit
178-
}
164+
state.flatModify {
165+
case s @ State.Active(subs, _, n, false) =>
166+
val inc = n + 1
167+
val newState = s.copy(publishing = inc)
168+
169+
val sends = subs.foldLeft(F.pure(true)) { case (acc, (_, chan)) =>
170+
chan.send(a).map(_.isRight).map2(acc)(_ && _)
171+
}
172+
173+
val action = sends.flatMap { allSucceeded =>
174+
state.flatModify {
175+
case s @ State.Active(subs, _, n, closing) =>
176+
val dec = n - 1
177+
if (dec == 0 && closing) {
178+
val closeAction = foreach(subs)(_.close.void)
179+
(State.Closed(), closeAction >> publishersFinished.complete(()).void)
180+
} else {
181+
(s.copy(publishing = dec), F.unit)
179182
}
183+
case s @ State.Closed() => (s, F.unit)
184+
}.map { _ =>
185+
if (allSucceeded) Topic.rightUnit else Topic.closed
180186
}
181187
}
188+
(newState, action)
189+
190+
case s @ State.Active(_, _, _, true) =>
191+
(s, Topic.closed.pure[F])
192+
case s @ State.Closed() =>
193+
(s, Topic.closed.pure[F])
182194
}
183195

184196
def subscribeAwait(maxQueued: Int): Resource[F, Stream[F, A]] =
@@ -194,18 +206,20 @@ object Topic {
194206
def subscribeAwaitImpl(chan: Channel[F, A]): Resource[F, Stream[F, A]] = {
195207
val subscribe: F[Option[Long]] =
196208
state.flatModify {
197-
case State.Active(subs, nextId) =>
198-
val newState = State.Active(subs.updated(nextId, chan), nextId + 1)
209+
case s @ State.Active(subs, nextId, _, false) =>
210+
val newState = s.copy(subscribers = subs.updated(nextId, chan), nextId = nextId + 1)
199211
val action = subscriberCount.update(_ + 1)
200212
val result = Some(nextId)
201213
newState -> action.as(result)
214+
case s @ State.Active(_, _, _, true) =>
215+
s -> F.pure(None)
202216
case closed @ State.Closed() =>
203217
closed -> F.pure(None)
204218
}
205219

206220
def unsubscribe(id: Long): F[Unit] =
207221
state.flatModify {
208-
case State.Active(subs, nextId) =>
222+
case s @ State.Active(subs, nextId, _, _) =>
209223
// _After_ we remove the bounded channel for this
210224
// subscriber, we need to drain it to unblock to
211225
// publish loop which might have already enqueued
@@ -215,7 +229,7 @@ object Topic {
215229
chan.close >> chan.stream.compile.drain
216230
}
217231

218-
State.Active(subs - id, nextId) -> (drainChannel *> subscriberCount.update(_ - 1))
232+
s.copy(subscribers = subs - id) -> (drainChannel *> subscriberCount.update(_ - 1))
219233

220234
case closed @ State.Closed() =>
221235
closed -> F.unit
@@ -249,9 +263,15 @@ object Topic {
249263

250264
def close: F[Either[Topic.Closed, Unit]] =
251265
state.flatModify {
252-
case State.Active(subs, _) =>
253-
val action = foreach(subs)(_.close.void) *> signalClosure.complete(())
254-
(State.Closed(), action.as(Topic.rightUnit))
266+
case s @ State.Active(subs, _, n, false) =>
267+
if (n == 0) {
268+
val action = foreach(subs)(_.close.void) *> signalClosure.complete(())
269+
(State.Closed(), (action >> publishersFinished.complete(())).as(Topic.rightUnit))
270+
} else {
271+
(s.copy(closing = true), publishersFinished.get.as(Topic.rightUnit))
272+
}
273+
case s @ State.Active(_, _, _, true) =>
274+
(s, publishersFinished.get.as(Topic.rightUnit))
255275
case closed @ State.Closed() =>
256276
(closed, Topic.closed.pure[F])
257277
}
@@ -266,13 +286,15 @@ object Topic {
266286
private object State {
267287
case class Active[F[_], A](
268288
subscribers: LongMap[Channel[F, A]],
269-
nextId: Long
289+
nextId: Long,
290+
publishing: Long,
291+
closing: Boolean
270292
) extends State[F, A]
271293

272294
case class Closed[F[_], A]() extends State[F, A]
273295

274296
def initial[F[_], A]: State[F, A] =
275-
Active(LongMap.empty, 1L)
297+
Active(LongMap.empty, 1L, 0L, false)
276298
}
277299

278300
private final val closed: Either[Closed, Unit] = Left(Closed)

0 commit comments

Comments
 (0)