Skip to content

Commit ea78dbc

Browse files
committed
fix: wait for all flow-graph connections before returning from FlowGraph()
feat(test): improve tests, add one for CloningFlow
1 parent f6a91e0 commit ea78dbc

File tree

6 files changed

+104
-14
lines changed

6 files changed

+104
-14
lines changed
Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
package dev.silenium.libs.flows.api
22

33
import kotlinx.coroutines.Job
4-
import kotlinx.coroutines.flow.launchIn
5-
import kotlinx.coroutines.flow.onEach
64

75
interface FlowGraphBuilder : FlowGraph {
8-
infix fun <T, P> Source<T, P>.connectTo(sink: Sink<T, P>): Job =
9-
flow.onEach(sink::submit).launchIn(this@FlowGraphBuilder)
6+
infix fun <T, P> Source<T, P>.connectTo(sink: Sink<T, P>): Result<Job>
7+
8+
suspend fun finalize(): Result<Unit>
109
}

src/main/kotlin/dev/silenium/libs/flows/impl/FlowGraphImpl.kt

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
package dev.silenium.libs.flows.impl
22

33
import dev.silenium.libs.flows.api.*
4-
import kotlinx.coroutines.CoroutineScope
5-
import kotlinx.coroutines.Dispatchers
6-
import kotlinx.coroutines.cancel
4+
import kotlinx.coroutines.*
75
import kotlin.coroutines.CoroutineContext
86
import kotlin.reflect.KClass
97

@@ -98,16 +96,37 @@ internal class FlowGraphImpl(private val coroutineScope: CoroutineScope) :
9896
) : CoroutineContext.Element
9997
}
10098

101-
internal class FlowGraphBuilderImpl(private val flowGraph: FlowGraph) : FlowGraphBuilder, FlowGraph by flowGraph
99+
internal class FlowGraphBuilderImpl(private val flowGraph: FlowGraph) : FlowGraphBuilder, FlowGraph by flowGraph {
100+
private val connectionStarted = mutableSetOf<Job>()
101+
102+
override fun <T, P> Source<T, P>.connectTo(sink: Sink<T, P>): Result<Job> {
103+
outputMetadata.forEach { (pad, metadata) ->
104+
sink.configure(pad, metadata).onFailure {
105+
return Result.failure(IllegalStateException("Unable to configure input pad $pad of sink $sink", it))
106+
}
107+
}
108+
val started = CompletableDeferred<Unit>()
109+
return launch {
110+
started.complete(Unit)
111+
flow.collect(sink)
112+
}.also {
113+
connectionStarted.add(started)
114+
}.let { Result.success(it) }
115+
}
116+
117+
override suspend fun finalize(): Result<Unit> = runCatching {
118+
connectionStarted.joinAll()
119+
}
120+
}
102121

103122
internal fun FlowGraph.builder() = FlowGraphBuilderImpl(this)
104123

105-
fun FlowGraph(
124+
suspend fun FlowGraph(
106125
coroutineContext: CoroutineContext = Dispatchers.Default,
107126
block: FlowGraphBuilder.() -> Unit,
108-
): FlowGraph = FlowGraphImpl(coroutineContext).builder().apply(block)
127+
): FlowGraph = FlowGraphImpl(coroutineContext).builder().apply(block).apply { finalize() }
109128

110-
fun FlowGraph(
129+
suspend fun FlowGraph(
111130
coroutineScope: CoroutineScope,
112131
block: FlowGraphBuilder.() -> Unit,
113-
): FlowGraph = FlowGraphImpl(coroutineScope).builder().apply(block)
132+
): FlowGraph = FlowGraphImpl(coroutineScope).builder().apply(block).apply { finalize() }

src/test/kotlin/dev/silenium/libs/flows/base/SourceBaseTest.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,12 @@ class SourceBaseTest : FunSpec({
3030
val job = CoroutineScope(Dispatchers.Default).launch {
3131
bufferSource.flow.collect(decoder)
3232
}
33+
val started = CompletableDeferred<Unit>()
3334
val listAsync = async(Dispatchers.Default) {
35+
started.complete(Unit)
3436
decoder.flow.toList()
3537
}
38+
started.await()
3639
val bufs = inputs.map { input ->
3740
val base64 = Base64.encode(input.encodeToByteArray())
3841
val byteBuffer = ByteBuffer.allocateDirect(base64.length)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package dev.silenium.libs.flows.impl
2+
3+
import dev.silenium.libs.flows.api.ReferenceCounted
4+
import io.kotest.core.spec.style.FunSpec
5+
import io.kotest.matchers.nulls.shouldBeNull
6+
import io.kotest.matchers.nulls.shouldNotBeNull
7+
import io.kotest.matchers.shouldBe
8+
import kotlinx.coroutines.CompletableDeferred
9+
import kotlinx.coroutines.async
10+
import kotlinx.coroutines.flow.toList
11+
import java.util.concurrent.atomic.AtomicLong
12+
import java.util.concurrent.atomic.AtomicReference
13+
14+
class TestData(
15+
val buf: AtomicReference<ByteArray?>,
16+
private val refCount_: AtomicLong = AtomicLong(0L)
17+
) : ReferenceCounted<TestData> {
18+
constructor(data: ByteArray) : this(AtomicReference(data))
19+
20+
init {
21+
refCount_.incrementAndGet()
22+
}
23+
24+
val refCount: Long
25+
get() = refCount_.get()
26+
27+
override fun clone(): Result<TestData> {
28+
return Result.success(TestData(buf, refCount_))
29+
}
30+
31+
override fun close() {
32+
if (refCount_.decrementAndGet() == 0L) {
33+
buf.set(null)
34+
}
35+
}
36+
}
37+
38+
class CloningFlowTest : FunSpec({
39+
test("cloning flow clones properly") {
40+
val flow = CloningFlow<TestData>()
41+
42+
val started = CompletableDeferred<Unit>()
43+
val items = async {
44+
started.complete(Unit)
45+
flow.toList()
46+
}
47+
started.await()
48+
49+
val inputs = listOf(
50+
"Hello, World!",
51+
"some text",
52+
"Another text",
53+
)
54+
inputs.forEach { input ->
55+
val data = TestData(input.encodeToByteArray())
56+
flow.publish(data)
57+
data.close()
58+
}
59+
60+
flow.close()
61+
items.await().map { item ->
62+
item.buf.get().shouldNotBeNull().decodeToString().also {
63+
item.close()
64+
item.refCount shouldBe 0L
65+
item.buf.get().shouldBeNull()
66+
}
67+
}.toSet() shouldBe inputs.toSet()
68+
}
69+
})

src/test/kotlin/dev/silenium/libs/flows/impl/FlowGraphImplTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class FlowGraphImplTest : FunSpec({
1313
test("FlowGraphBuilder") {
1414
val graph = FlowGraph(CoroutineScope(Dispatchers.Default)) {
1515
val source = source(BufferSource<Base64Buffer, DataType>(0u to DataType.BASE64), "buffer-source")
16-
val sink = sink(BufferSink<ByteArray, DataType>(0u to DataType.PLAIN), "buffer-sink")
16+
val sink = sink(BufferSink<ByteArray, DataType>(), "buffer-sink")
1717
val decoder = transformer(Base64Decoder(), "base64-decoder")
1818
source.connectTo(decoder)
1919
decoder.connectTo(sink)

src/test/kotlin/dev/silenium/libs/flows/test/BufferSink.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class BufferSink<T, P>(vararg pads: Pair<UInt, P>) : Sink<T, P> {
1212

1313
private val buffer_: MutableMap<UInt, MutableList<FlowItem<T, P>>> = mutableMapOf()
1414
val buffer: Map<UInt, List<FlowItem<T, P>>> by ::buffer_
15-
val flow_ = MutableStateFlow<Map<UInt, List<FlowItem<T, P>>>>(emptyMap())
15+
private val flow_ = MutableStateFlow<Map<UInt, List<FlowItem<T, P>>>>(emptyMap())
1616
val flow: StateFlow<Map<UInt, List<FlowItem<T, P>>>> = flow_.asStateFlow()
1717

1818
override suspend fun submit(item: FlowItem<T, P>): Result<Unit> {

0 commit comments

Comments
 (0)