Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 16 additions & 20 deletions core/src/main/scala/ox/flow/internal/groupByImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,14 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
case class ChildDone(v: V)

// Running a pending child flow, after another has completed as done
def runChild_ifPending(state: GroupByState, childDone: Sink[ChildDone], childOutput: Sink[U])(using OxUnsupervised): GroupByState =
def runChild_ifPending(state: GroupByState, childOutput: Sink[U | ChildDone])(using OxUnsupervised): GroupByState =
state.pendingFromParent match
case Some((t, v, counter)) =>
sendToChild_orRunChild_orBuffer(state.copy(pendingFromParent = None), childDone, childOutput, t, v, counter)
case None => state
case Some((t, v, counter)) => sendToChild_orRunChild_orBuffer(state.copy(pendingFromParent = None), childOutput, t, v, counter)
case None => state

def sendToChild_orRunChild_orBuffer(
state: GroupByState,
childDone: Sink[ChildDone],
childOutput: Sink[U],
childOutput: Sink[U | ChildDone],
t: T,
v: V,
counter: Long
Expand All @@ -69,14 +67,14 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic

case None if s.children.size < parallelism =>
// Starting a new child flow, running in the background; the child flow receives values via a channel,
// and feeds its output to `childOutput`. Done signals are forwarded to `childDone`; elements & errors
// are propagated to `childOutput`.
// and feeds its output to `childOutput`. Done signals are propagated as `ChildDone` values; errors are
// propagated as channel errors.
val childChannel = BufferCapacity.newChannel[T]
s = s.withChildAdded(v, childChannel)

forkUnsupervised:
childFlowTransform(v)(Flow.fromSource(childChannel))
.onDone(childDone.sendOrClosed(ChildDone(v)).discard)
.onDone(childOutput.sendOrClosed(ChildDone(v)).discard)
// When the child flow is done, make sure that the source channel becomes closed as well; otherwise,
// we'd be risking a deadlock if there are `childChannel.send`-s pending and the buffer is full.
// If the channel is already closed, this is a no-op.
Expand All @@ -92,7 +90,7 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
s = s.withPendingFromParent(t, v, counter)

// Completing as done the child flow which didn't receive an element for the longest time. After
// the flow completes, it will send `ChildDone` to `childDone`.
// the flow completes, it will send `ChildDone` to `childOutput`.
s = s.withoutLongestInactiveChild.pipe { (vOpt, s2) =>
vOpt.foreach(s.children(_).done())
s2
Expand All @@ -107,11 +105,9 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
// Channel where all elements emitted by child flows will be sent; we use such a collective channel instead of
// enumerating all child channels in the main `select`, as `select`s don't scale well with the number of
// clauses. The elements from this channel are then emitted by the returned flow.
val childOutput = BufferCapacity.newChannel[U]

// Channel where completion of children is signalled (because the parent is complete, or the parallelism limit
// is reached).
val childDone = Channel.unlimited[ChildDone]
// Completion of children (when the parent is complete, or the parallelism limit is reached) is signalled on
// this channel as well.
val childOutput = BufferCapacity.newChannel[U | ChildDone]

// Parent channel, from which we receive as long as it's not done, and only when a child flow isn't pending
// creation (see below). As the receive is conditional, the errors that occur on this channel are also
Expand All @@ -132,12 +128,12 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
// values before marking a child as done.
val pool =
if state.shouldReceiveFromParentChannel
then List(childOutput, childDone)
else List(childOutput, childDone, parentChannel)
then List(childOutput)
else List(childOutput, parentChannel)

selectOrClosed(pool) match
case ChannelClosed.Done =>
// Only the parent can be done; child completion is signalled via a value in `childDone`.
// Only the parent can be done; child completion is signalled via a value in `childOutput`.
state = state.withParentDone(isSourceDone(parentChannel))
assert(state.parentDone)

Expand All @@ -153,7 +149,7 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic

case FromParent(t) =>
state = state.withFromParentCounterIncremented
state = sendToChild_orRunChild_orBuffer(state, childDone, childOutput, t, predicate(t), state.fromParentCounter)
state = sendToChild_orRunChild_orBuffer(state, childOutput, t, predicate(t), state.fromParentCounter)

case ChildDone(v) =>
state = state.withChildRemoved(v)
Expand All @@ -168,7 +164,7 @@ private[flow] def groupByImpl[T, V, U](parent: Flow[T], parallelism: Int, predic
"childFlowTransform), while this is not allowed (see documentation for details)"
)

state = runChild_ifPending(state, childDone, childOutput)
state = runChild_ifPending(state, childOutput)

case u: U @unchecked => emit(u) // forwarding from `childOutput`
end match
Expand Down
6 changes: 6 additions & 0 deletions core/src/test/scala/ox/flow/FlowOpsGroupByTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ class FlowOpsGroupByTest extends AnyFlatSpec with Matchers:
it should "handle single-element flow" in:
Flow.fromValues(42).groupBy(10, _ % 10)(v => f => f).runToList() shouldBe List(42)

it should "handle single-element flow (stress test)" in:
// This test failed with a previous implementation, which used separate channels for receiving child elements
// (childOutput) and for signalling child completion; the select then sometimes chose the done clause before
// the value clause, which resulted in the value being dropped.
for i <- 1 to 100000 do Flow.fromValues(42).groupBy(10, _ % 10)(v => f => f).runToList() shouldBe List(42)

it should "create simple groups without reaching parallelism limit" in:
case class Group(v: Int, values: List[Int])

Expand Down