Skip to content

Commit 45ae00f

Browse files
committed
Propagate the context in Coroutines transactions
This commit ensures that CoroutineContext is properly propagated in transactional suspending functions. Both annotation and functional variants are supported. Closes gh-27308
1 parent 3e2f58c commit 45ae00f

File tree

4 files changed

+128
-11
lines changed

4 files changed

+128
-11
lines changed

spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -22,6 +22,8 @@
2222

2323
import io.vavr.control.Try;
2424
import kotlin.coroutines.Continuation;
25+
import kotlin.coroutines.CoroutineContext;
26+
import kotlinx.coroutines.Job;
2527
import kotlinx.coroutines.reactive.AwaitKt;
2628
import kotlinx.coroutines.reactive.ReactiveFlowKt;
2729
import org.apache.commons.logging.Log;
@@ -363,7 +365,7 @@ protected Object invokeWithinTransaction(Method method, @Nullable Class<?> targe
363365

364366
InvocationCallback callback = invocation;
365367
if (corInv != null) {
366-
callback = () -> CoroutinesUtils.invokeSuspendingFunction(method, corInv.getTarget(), corInv.getArguments());
368+
callback = () -> KotlinDelegate.invokeSuspendingFunction(method, corInv);
367369
}
368370
Object result = txSupport.invokeWithinTransaction(method, targetClass, callback, txAttr, (ReactiveTransactionManager) tm);
369371
if (corInv != null) {
@@ -883,6 +885,12 @@ private static Object asFlow(Publisher<?> publisher) {
883885
private static Object awaitSingleOrNull(Publisher<?> publisher, Object continuation) {
884886
return AwaitKt.awaitSingleOrNull(publisher, (Continuation<Object>) continuation);
885887
}
888+
889+
public static Publisher<?> invokeSuspendingFunction(Method method, CoroutinesInvocationCallback callback) {
890+
CoroutineContext coroutineContext = ((Continuation<?>) callback.getContinuation()).getContext().minusKey(Job.Key);
891+
return CoroutinesUtils.invokeSuspendingFunction(coroutineContext, method, callback.getTarget(), callback.getArguments());
892+
}
893+
886894
}
887895

888896

spring-tx/src/main/kotlin/org/springframework/transaction/reactive/TransactionalOperatorExtensions.kt

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,26 @@
1616

1717
package org.springframework.transaction.reactive
1818

19-
import java.util.Optional
20-
import kotlinx.coroutines.Dispatchers
19+
import kotlinx.coroutines.Job
20+
import kotlinx.coroutines.currentCoroutineContext
2121
import kotlinx.coroutines.flow.Flow
2222
import kotlinx.coroutines.reactive.asFlow
2323
import kotlinx.coroutines.reactive.awaitLast
2424
import kotlinx.coroutines.reactor.asFlux
2525
import kotlinx.coroutines.reactor.mono
2626
import org.springframework.transaction.ReactiveTransaction
27+
import java.util.*
28+
import kotlin.coroutines.CoroutineContext
29+
import kotlin.coroutines.EmptyCoroutineContext
2730

2831
/**
2932
* Coroutines variant of [TransactionalOperator.transactional] as a [Flow] extension.
3033
*
3134
* @author Sebastien Deleuze
3235
* @since 5.2
3336
*/
34-
fun <T : Any> Flow<T>.transactional(operator: TransactionalOperator): Flow<T> =
35-
operator.transactional(asFlux()).asFlow()
37+
fun <T : Any> Flow<T>.transactional(operator: TransactionalOperator, context: CoroutineContext = EmptyCoroutineContext): Flow<T> =
38+
operator.transactional(asFlux(context)).asFlow()
3639

3740
/**
3841
* Coroutines variant of [TransactionalOperator.execute] with a suspending lambda
@@ -42,6 +45,8 @@ fun <T : Any> Flow<T>.transactional(operator: TransactionalOperator): Flow<T> =
4245
* @author Mark Paluch
4346
* @since 5.2
4447
*/
45-
suspend fun <T> TransactionalOperator.executeAndAwait(f: suspend (ReactiveTransaction) -> T): T =
46-
execute { status -> mono(Dispatchers.Unconfined) { f(status) } }.map { value -> Optional.ofNullable(value) }
48+
suspend fun <T> TransactionalOperator.executeAndAwait(f: suspend (ReactiveTransaction) -> T): T {
49+
val context = currentCoroutineContext().minusKey(Job.Key)
50+
return execute { status -> mono(context) { f(status) } }.map { value -> Optional.ofNullable(value) }
4751
.defaultIfEmpty(Optional.empty()).awaitLast().orElse(null)
52+
}

spring-tx/src/test/kotlin/org/springframework/transaction/annotation/CoroutinesAnnotationTransactionInterceptorTests.kt

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@ import kotlinx.coroutines.flow.Flow
2121
import kotlinx.coroutines.flow.flow
2222
import kotlinx.coroutines.flow.toList
2323
import kotlinx.coroutines.runBlocking
24-
import org.assertj.core.api.Assertions.assertThat
25-
import org.assertj.core.api.Assertions.fail
24+
import kotlinx.coroutines.withContext
25+
import org.assertj.core.api.Assertions.*
2626
import org.junit.jupiter.api.Test
2727
import org.springframework.aop.framework.ProxyFactory
2828
import org.springframework.transaction.interceptor.TransactionInterceptor
2929
import org.springframework.transaction.testfixture.ReactiveCallCountingTransactionManager
30+
import kotlin.coroutines.AbstractCoroutineContextElement
31+
import kotlin.coroutines.CoroutineContext
32+
import kotlin.coroutines.coroutineContext
3033

3134
/**
3235
* @author Sebastien Deleuze
@@ -118,6 +121,36 @@ class CoroutinesAnnotationTransactionInterceptorTests {
118121
assertReactiveGetTransactionAndCommitCount(1)
119122
}
120123

124+
@Test
125+
fun suspendingValueSuccessWithContext() {
126+
val proxyFactory = ProxyFactory()
127+
proxyFactory.setTarget(TestWithCoroutines())
128+
proxyFactory.addAdvice(TransactionInterceptor(rtm, source))
129+
val proxy = proxyFactory.proxy as TestWithCoroutines
130+
assertThat(runBlocking {
131+
withExampleContext("context") {
132+
proxy.suspendingValueSuccessWithContext()
133+
}
134+
}).isEqualTo("context")
135+
assertReactiveGetTransactionAndCommitCount(1)
136+
}
137+
138+
@Test
139+
fun suspendingValueFailureWithContext() {
140+
val proxyFactory = ProxyFactory()
141+
proxyFactory.setTarget(TestWithCoroutines())
142+
proxyFactory.addAdvice(TransactionInterceptor(rtm, source))
143+
val proxy = proxyFactory.proxy as TestWithCoroutines
144+
assertThatIllegalStateException().isThrownBy {
145+
runBlocking {
146+
withExampleContext("context") {
147+
proxy.suspendingValueFailureWithContext()
148+
}
149+
}
150+
}.withMessage("context")
151+
assertReactiveGetTransactionAndRollbackCount(1)
152+
}
153+
121154
private fun assertReactiveGetTransactionAndCommitCount(expectedCount: Int) {
122155
assertThat(rtm.begun).isEqualTo(expectedCount)
123156
assertThat(rtm.commits).isEqualTo(expectedCount)
@@ -166,5 +199,27 @@ class CoroutinesAnnotationTransactionInterceptorTests {
166199
emit("foo")
167200
}
168201
}
202+
203+
open suspend fun suspendingValueSuccessWithContext(): String {
204+
delay(10)
205+
return coroutineContext[ExampleContext.Key].toString()
206+
}
207+
208+
open suspend fun suspendingValueFailureWithContext(): String {
209+
delay(10)
210+
throw IllegalStateException(coroutineContext[ExampleContext.Key].toString())
211+
}
169212
}
170213
}
214+
215+
data class ExampleContext(val value: String) : AbstractCoroutineContextElement(ExampleContext) {
216+
217+
companion object Key : CoroutineContext.Key<ExampleContext>
218+
219+
override fun toString(): String = value
220+
}
221+
222+
private suspend fun withExampleContext(inputValue: String, f: suspend () -> String) =
223+
withContext(ExampleContext(inputValue)) {
224+
f()
225+
}

spring-tx/src/test/kotlin/org/springframework/transaction/reactive/TransactionalOperatorExtensionsTests.kt

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,13 +16,16 @@
1616

1717
package org.springframework.transaction.reactive
1818

19+
import kotlinx.coroutines.currentCoroutineContext
1920
import kotlinx.coroutines.delay
2021
import kotlinx.coroutines.flow.flow
2122
import kotlinx.coroutines.flow.toList
2223
import kotlinx.coroutines.runBlocking
2324
import org.assertj.core.api.Assertions.assertThat
2425
import org.junit.jupiter.api.Test
2526
import org.springframework.transaction.support.DefaultTransactionDefinition
27+
import kotlin.coroutines.AbstractCoroutineContextElement
28+
import kotlin.coroutines.CoroutineContext
2629

2730
class TransactionalOperatorExtensionsTests {
2831

@@ -107,4 +110,50 @@ class TransactionalOperatorExtensionsTests {
107110
}
108111
}
109112
}
113+
114+
@Test
115+
fun coroutineContextWithSuspendingFunction() {
116+
val operator = TransactionalOperator.create(tm, DefaultTransactionDefinition())
117+
runBlocking(User(role = "admin")) {
118+
try {
119+
operator.executeAndAwait {
120+
delay(1)
121+
val currentUser = currentCoroutineContext()[User]
122+
assertThat(currentUser).isNotNull()
123+
assertThat(currentUser!!.role).isEqualTo("admin")
124+
throw IllegalStateException()
125+
}
126+
} catch (e: IllegalStateException) {
127+
assertThat(tm.commit).isFalse()
128+
assertThat(tm.rollback).isTrue()
129+
return@runBlocking
130+
}
131+
}
132+
}
133+
134+
@Test
135+
fun coroutineContextWithFlow() {
136+
val operator = TransactionalOperator.create(tm, DefaultTransactionDefinition())
137+
val flow = flow<Int> {
138+
delay(1)
139+
val currentUser = currentCoroutineContext()[User]
140+
assertThat(currentUser).isNotNull()
141+
assertThat(currentUser!!.role).isEqualTo("admin")
142+
throw IllegalStateException()
143+
}
144+
runBlocking(User(role = "admin")) {
145+
try {
146+
flow.transactional(operator, coroutineContext).toList()
147+
} catch (e: IllegalStateException) {
148+
assertThat(tm.commit).isFalse()
149+
assertThat(tm.rollback).isTrue()
150+
return@runBlocking
151+
}
152+
}
153+
}
154+
155+
156+
private data class User(val role: String) : AbstractCoroutineContextElement(User) {
157+
companion object Key : CoroutineContext.Key<User>
158+
}
110159
}

0 commit comments

Comments
 (0)