Skip to content

Commit 1c6c69b

Browse files
authored
Fix data race in Flow.groupBy (#346)
1 parent ee75030 commit 1c6c69b

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

core/src/main/scala/ox/flow/internal/groupByImpl.scala

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,14 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
4747
case class ChildDone(v: V)
4848

4949
// Running a pending child flow, after another has completed as done
50-
def runChild_ifPending(state: GroupByState, childDone: Sink[ChildDone], childOutput: Sink[U])(using OxUnsupervised): GroupByState =
50+
def runChild_ifPending(state: GroupByState, childOutput: Sink[U | ChildDone])(using OxUnsupervised): GroupByState =
5151
state.pendingFromParent match
52-
case Some((t, v, counter)) =>
53-
sendToChild_orRunChild_orBuffer(state.copy(pendingFromParent = None), childDone, childOutput, t, v, counter)
54-
case None => state
52+
case Some((t, v, counter)) => sendToChild_orRunChild_orBuffer(state.copy(pendingFromParent = None), childOutput, t, v, counter)
53+
case None => state
5554

5655
def sendToChild_orRunChild_orBuffer(
5756
state: GroupByState,
58-
childDone: Sink[ChildDone],
59-
childOutput: Sink[U],
57+
childOutput: Sink[U | ChildDone],
6058
t: T,
6159
v: V,
6260
counter: Long
@@ -69,14 +67,14 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
6967

7068
case None if s.children.size < parallelism =>
7169
// Starting a new child flow, running in the background; the child flow receives values via a channel,
72-
// and feeds its output to `childOutput`. Done signals are forwarded to `childDone`; elements & errors
73-
// are propagated to `childOutput`.
70+
// and feeds its output to `childOutput`. Done signals are propagated as values; errors are propagated
71+
// as channel errors.
7472
val childChannel = BufferCapacity.newChannel[T]
7573
s = s.withChildAdded(v, childChannel)
7674

7775
forkUnsupervised:
7876
childFlowTransform(v)(Flow.fromSource(childChannel))
79-
.onDone(childDone.sendOrClosed(ChildDone(v)).discard)
77+
.onDone(childOutput.sendOrClosed(ChildDone(v)).discard)
8078
// When the child flow is done, making sure that the source channel becomes closed as well
8179
// otherwise, we'd be risking a deadlock, if there are `childChannel.send`-s pending, and the
8280
// buffer is full; if the channel is already closed, this is a no-op.
@@ -92,7 +90,7 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
9290
s = s.withPendingFromParent(t, v, counter)
9391

9492
// Completing as done the child flow which didn't receive an element for the longest time. After
95-
// the flow completes, it will send `ChildDone` to `childDone`.
93+
// the flow completes, it will send `ChildDone` to `childOutput`.
9694
s = s.withoutLongestInactiveChild.pipe { (vOpt, s2) =>
9795
vOpt.foreach(s.children(_).done())
9896
s2
@@ -107,11 +105,9 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
107105
// Channel where all elements emitted by child flows will be sent; we use such a collective channel instead of
108106
// enumerating all child channels in the main `select`, as `select`s don't scale well with the number of
109107
// clauses. The elements from this channel are then emitted by the returned flow.
110-
val childOutput = BufferCapacity.newChannel[U]
111-
112-
// Channel where completion of children is signalled (because the parent is complete, or the parallelism limit
113-
// is reached).
114-
val childDone = Channel.unlimited[ChildDone]
108+
// Completion of children (when the parent is complete, or the parallelism limit is reached) is signalled on
109+
// this channel as well.
110+
val childOutput = BufferCapacity.newChannel[U | ChildDone]
115111

116112
// Parent channel, from which we receive as long as it's not done, and only when a child flow isn't pending
117113
// creation (see below). As the receive is conditional, the errors that occur on this channel are also
@@ -132,12 +128,12 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
132128
// values before marking a child as done.
133129
val pool =
134130
if state.shouldReceiveFromParentChannel
135-
then List(childOutput, childDone)
136-
else List(childOutput, childDone, parentChannel)
131+
then List(childOutput)
132+
else List(childOutput, parentChannel)
137133

138134
selectOrClosed(pool) match
139135
case ChannelClosed.Done =>
140-
// Only the parent can be done; child completion is signalled via a value in `childDone`.
136+
// Only the parent can be done; child completion is signalled via a value in `childOutput`.
141137
state = state.withParentDone(isSourceDone(parentChannel))
142138
assert(state.parentDone)
143139

@@ -153,7 +149,7 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
153149

154150
case FromParent(t) =>
155151
state = state.withFromParentCounterIncremented
156-
state = sendToChild_orRunChild_orBuffer(state, childDone, childOutput, t, predicate(t), state.fromParentCounter)
152+
state = sendToChild_orRunChild_orBuffer(state, childOutput, t, predicate(t), state.fromParentCounter)
157153

158154
case ChildDone(v) =>
159155
state = state.withChildRemoved(v)
@@ -168,7 +164,7 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
168164
"childFlowTransform), while this is not allowed (see documentation for details)"
169165
)
170166

171-
state = runChild_ifPending(state, childDone, childOutput)
167+
state = runChild_ifPending(state, childOutput)
172168

173169
case u: U @unchecked => emit(u) // forwarding from `childOutput`
174170
end match

core/src/test/scala/ox/flow/FlowOpsGroupByTest.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ class FlowOpsGroupByTest extends AnyFlatSpec with Matchers:
2121
it should "handle single-element flow" in:
2222
Flow.fromValues(42).groupBy(10, _ % 10)(v => f => f).runToList() shouldBe List(42)
2323

24+
it should "handle single-element flow (stress test)" in:
25+
// this test failed with a previous implementation which used separate channels for receiving child elements
26+
// (childOutput) and for signalling children completion; the select then sometimes chose the done clause before
27+
// the value clause, which resulted in dropping the value
28+
for i <- 1 to 100000 do Flow.fromValues(42).groupBy(10, _ % 10)(v => f => f).runToList() shouldBe List(42)
29+
2430
it should "create simple groups without reaching parallelism limit" in:
2531
case class Group(v: Int, values: List[Int])
2632

0 commit comments

Comments
 (0)