Skip to content

Commit 38f6ae0

Browse files
Merge pull request #1375 from square/zachklipp/workflow-multithreading
Fix workflow freezing thread safety.
2 parents 96997a2 + 006a0a2 commit 38f6ae0

File tree

10 files changed

+327
-36
lines changed

10 files changed

+327
-36
lines changed

workflow-runtime/src/appleMain/kotlin/com/squareup/workflow1/internal/Synchronization.apple.kt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
package com.squareup.workflow1.internal
22

3+
import kotlinx.cinterop.CPointer
4+
import kotlinx.cinterop.ExperimentalForeignApi
5+
import platform.Foundation.NSCopyingProtocol
36
import platform.Foundation.NSLock
7+
import platform.Foundation.NSThread
8+
import platform.Foundation.NSZone
9+
import platform.darwin.NSObject
410

11+
/**
12+
* Creates a lock that, after locking, must only be unlocked by the thread that acquired the lock.
13+
*
14+
* See the docs: https://developer.apple.com/documentation/foundation/nslock#overview
15+
*/
516
internal actual typealias Lock = NSLock
617

718
internal actual inline fun <R> Lock.withLock(block: () -> R): R {
@@ -12,3 +23,35 @@ internal actual inline fun <R> Lock.withLock(block: () -> R): R {
1223
unlock()
1324
}
1425
}
26+
27+
/**
28+
* Implementation of [ThreadLocal] that works in a similar way to Java's, based on a thread-specific
29+
* map/dictionary.
30+
*/
31+
internal actual class ThreadLocal<T>(
32+
private val initialValue: () -> T
33+
) : NSObject(), NSCopyingProtocol {
34+
35+
private val threadDictionary
36+
get() = NSThread.currentThread().threadDictionary
37+
38+
actual fun get(): T {
39+
@Suppress("UNCHECKED_CAST")
40+
return (threadDictionary.objectForKey(aKey = this) as T?)
41+
?: initialValue().also(::set)
42+
}
43+
44+
actual fun set(value: T) {
45+
threadDictionary.setObject(value, forKey = this)
46+
}
47+
48+
/**
49+
* [Docs](https://developer.apple.com/documentation/foundation/nscopying/copy(with:)) say [zone]
50+
* is unused.
51+
*/
52+
@OptIn(ExperimentalForeignApi::class)
53+
override fun copyWithZone(zone: CPointer<NSZone>?): Any = this
54+
}
55+
56+
internal actual fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T> =
57+
ThreadLocal(initialValue)

workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/RealRenderContext.kt

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,24 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
4646
}
4747

4848
/**
49-
* False during the current render call, set to true once this node is finished rendering.
49+
* False except while this [WorkflowNode] is running the workflow's `render` method.
5050
*
5151
* Used to:
52-
* - prevent modifications to this object after [freeze] is called.
53-
* - prevent sending to sinks before render returns.
52+
* - Prevent modifications to this object after [freeze] is called (e.g. [renderChild] calls).
53+
* Only allowed when this flag is true.
54+
* - Prevent sending to sinks before render returns. Only allowed when this flag is false.
55+
*
56+
* This is a [ThreadLocal] since we only care about preventing calls during rendering from the
57+
* thread that is actually doing the rendering. If a background thread happens to send something
58+
* into the sink, for example, while the main thread is rendering, it's not a violation.
5459
*/
55-
private var frozen = false
60+
private var performingRender by threadLocalOf { false }
5661

5762
override val actionSink: Sink<WorkflowAction<PropsT, StateT, OutputT>> get() = this
5863

5964
override fun send(value: WorkflowAction<PropsT, StateT, OutputT>) {
60-
if (!frozen) {
65+
// Can't send actions from render thread during render pass.
66+
if (performingRender) {
6167
throw UnsupportedOperationException(
6268
"Expected sink to not be sent to until after the render pass. " +
6369
"Received action: ${value.debuggingName}"
@@ -72,7 +78,7 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
7278
key: String,
7379
handler: (ChildOutputT) -> WorkflowAction<PropsT, StateT, OutputT>
7480
): ChildRenderingT {
75-
checkNotFrozen(child.identifier) {
81+
checkPerformingRender(child.identifier) {
7682
"renderChild(${child.identifier})"
7783
}
7884
return renderer.render(child, props, key, handler)
@@ -82,7 +88,7 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
8288
key: String,
8389
sideEffect: suspend CoroutineScope.() -> Unit
8490
) {
85-
checkNotFrozen(key) { "runningSideEffect($key)" }
91+
checkPerformingRender(key) { "runningSideEffect($key)" }
8692
sideEffectRunner.runningSideEffect(key, sideEffect)
8793
}
8894

@@ -92,23 +98,22 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
9298
vararg inputs: Any?,
9399
calculation: () -> ResultT
94100
): ResultT {
95-
checkNotFrozen(key) { "remember($key)" }
101+
checkPerformingRender(key) { "remember($key)" }
96102
return rememberStore.remember(key, resultType, inputs = inputs, calculation)
97103
}
98104

99105
/**
100106
* Freezes this context so that any further calls to this context will throw.
101107
*/
102108
fun freeze() {
103-
checkNotFrozen("freeze") { "freeze" }
104-
frozen = true
109+
performingRender = false
105110
}
106111

107112
/**
108113
* Unfreezes when the node is about to render() again.
109114
*/
110115
fun unfreeze() {
111-
frozen = false
116+
performingRender = true
112117
}
113118

114119
/**
@@ -117,8 +122,10 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
117122
*
118123
* @see checkWithKey
119124
*/
120-
private inline fun checkNotFrozen(stackTraceKey: Any, lazyMessage: () -> Any) =
121-
checkWithKey(!frozen, stackTraceKey) {
122-
"RenderContext cannot be used after render method returns: ${lazyMessage()}"
123-
}
125+
private inline fun checkPerformingRender(
126+
stackTraceKey: Any,
127+
lazyMessage: () -> Any
128+
) = checkWithKey(performingRender, stackTraceKey) {
129+
"RenderContext cannot be used after render method returns: ${lazyMessage()}"
130+
}
124131
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,29 @@
11
package com.squareup.workflow1.internal
22

3+
import kotlin.reflect.KProperty
4+
35
internal expect class Lock()
46

57
internal expect inline fun <R> Lock.withLock(block: () -> R): R
8+
9+
internal expect class ThreadLocal<T> {
10+
fun get(): T
11+
fun set(value: T)
12+
}
13+
14+
internal expect fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T>
15+
16+
@Suppress("NOTHING_TO_INLINE")
17+
internal inline operator fun <T> ThreadLocal<T>.getValue(
18+
receiver: Any?,
19+
property: KProperty<*>
20+
): T = get()
21+
22+
@Suppress("NOTHING_TO_INLINE")
23+
internal inline operator fun <T> ThreadLocal<T>.setValue(
24+
receiver: Any?,
25+
property: KProperty<*>,
26+
value: T
27+
) {
28+
set(value)
29+
}

workflow-runtime/src/commonTest/kotlin/com/squareup/workflow1/internal/RealRenderContextTest.kt

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,10 @@ internal class RealRenderContextTest {
220220

221221
val child = Workflow.stateless<Unit, Nothing, Unit> { fail() }
222222
assertFailsWith<IllegalStateException> { context.renderChild(child) }
223-
assertFailsWith<IllegalStateException> { context.freeze() }
224223
assertFailsWith<IllegalStateException> { context.remember("key", typeOf<String>()) {} }
224+
225+
// Freeze is the exception, it's idempotent and can be called again.
226+
context.freeze()
225227
}
226228

227229
private fun createdPoisonedContext(): RealRenderContext<String, String, String> {
@@ -234,7 +236,9 @@ internal class RealRenderContextTest {
234236
eventActionsChannel,
235237
workflowTracer = null,
236238
runtimeConfig = emptySet(),
237-
)
239+
).apply {
240+
unfreeze()
241+
}
238242
}
239243

240244
private fun createTestContext(): RealRenderContext<String, String, String> {
@@ -247,6 +251,8 @@ internal class RealRenderContextTest {
247251
eventActionsChannel,
248252
workflowTracer = null,
249253
runtimeConfig = emptySet(),
250-
)
254+
).apply {
255+
unfreeze()
256+
}
251257
}
252258
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package com.squareup.workflow1.internal
2+
3+
import platform.Foundation.NSCondition
4+
import platform.Foundation.NSThread
5+
import kotlin.concurrent.Volatile
6+
import kotlin.test.Test
7+
import kotlin.test.assertEquals
8+
9+
class ThreadLocalTest {
10+
11+
@Volatile
12+
private var valueFromThread: Int = -1
13+
14+
@Test fun initialValue() {
15+
val threadLocal = ThreadLocal(initialValue = { 42 })
16+
assertEquals(42, threadLocal.get())
17+
}
18+
19+
@Test fun settingValue() {
20+
val threadLocal = ThreadLocal(initialValue = { 42 })
21+
threadLocal.set(0)
22+
assertEquals(0, threadLocal.get())
23+
}
24+
25+
@Test fun initialValue_inSeparateThread_afterChanging() {
26+
val threadLocal = ThreadLocal(initialValue = { 42 })
27+
threadLocal.set(0)
28+
29+
val thread = NSThread {
30+
valueFromThread = threadLocal.get()
31+
}
32+
thread.start()
33+
thread.join()
34+
35+
assertEquals(42, valueFromThread)
36+
}
37+
38+
@Test fun set_fromDifferentThreads_doNotConflict() {
39+
val threadLocal = ThreadLocal(initialValue = { 0 })
40+
// threadStartedLatch and firstReadLatch together form a barrier: the allow the background
41+
// to start up and get to the same point as the test thread, just before writing to the
42+
// ThreadLocal, before allowing both threads to perform the write as close to the same time as
43+
// possible.
44+
val threadStartedLatch = NSCondition()
45+
val firstReadLatch = NSCondition()
46+
val firstReadDoneLatch = NSCondition()
47+
val secondReadLatch = NSCondition()
48+
49+
val thread = NSThread {
50+
// Wait on the barrier to sync with the test thread.
51+
threadStartedLatch.signal()
52+
firstReadLatch.wait()
53+
threadLocal.set(1)
54+
55+
// Ensure we can see our read immediately, then wait for the test thread to verify. This races
56+
// with the set(2) in the test thread, but that's fine. We'll double-check the value later.
57+
valueFromThread = threadLocal.get()
58+
firstReadDoneLatch.signal()
59+
secondReadLatch.wait()
60+
61+
// Read one last time since now the test thread's second write is done.
62+
valueFromThread = threadLocal.get()
63+
}
64+
thread.start()
65+
66+
// Wait for the other thread to start, then both threads set the value to something different
67+
// at the same time.
68+
threadStartedLatch.wait()
69+
firstReadLatch.signal()
70+
threadLocal.set(2)
71+
72+
// Wait for the background thread to finish setting value, then ensure that both threads see
73+
// independent values.
74+
firstReadDoneLatch.wait()
75+
assertEquals(1, valueFromThread)
76+
assertEquals(2, threadLocal.get())
77+
78+
// Change the value in this thread then read it again from the background thread.
79+
threadLocal.set(3)
80+
secondReadLatch.signal()
81+
thread.join()
82+
assertEquals(1, valueFromThread)
83+
}
84+
85+
private fun NSThread.join() {
86+
while (!isFinished()) {
87+
// Avoid being optimized out.
88+
// Time interval is in seconds.
89+
NSThread.sleepForTimeInterval(1.0 / 1000)
90+
}
91+
}
92+
}

workflow-runtime/src/jsMain/kotlin/com/squareup/workflow1/internal/Synchronization.js.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,13 @@ package com.squareup.workflow1.internal
55
internal actual typealias Lock = Any
66

77
internal actual inline fun <R> Lock.withLock(block: () -> R): R = block()
8+
9+
internal actual class ThreadLocal<T>(private var value: T) {
10+
actual fun get(): T = value
11+
actual fun set(value: T) {
12+
this.value = value
13+
}
14+
}
15+
16+
internal actual fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T> =
17+
ThreadLocal(initialValue())

workflow-runtime/src/jvmMain/kotlin/com/squareup/workflow1/internal/Synchronization.jvm.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,8 @@ package com.squareup.workflow1.internal
33
internal actual typealias Lock = Any
44

55
internal actual inline fun <R> Lock.withLock(block: () -> R): R = synchronized(this, block)
6+
7+
internal actual typealias ThreadLocal<T> = java.lang.ThreadLocal<T>
8+
9+
internal actual fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T> =
10+
ThreadLocal.withInitial(initialValue)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package com.squareup.workflow1
2+
3+
import java.util.concurrent.CountDownLatch
4+
5+
/**
6+
* Returns the maximum number of threads that can be run in parallel on the host system, rounded
7+
* down to the nearest even number, and at least 2.
8+
*/
9+
internal fun calculateSaturatingTestThreadCount(minThreads: Int) =
10+
Runtime.getRuntime().availableProcessors().let {
11+
if (it.mod(2) != 0) it - 1 else it
12+
}.coerceAtLeast(minThreads)
13+
14+
/**
15+
* Calls [CountDownLatch.await] in a loop until count is zero, even if the thread gets
16+
* interrupted.
17+
*/
18+
@Suppress("CheckResult")
19+
internal fun CountDownLatch.awaitUntilDone() {
20+
while (count > 0) {
21+
try {
22+
await()
23+
} catch (e: InterruptedException) {
24+
// Continue
25+
}
26+
}
27+
}

0 commit comments

Comments
 (0)