Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@

import org.jspecify.annotations.Nullable;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ReactiveAdapter;
Expand All @@ -54,6 +56,8 @@
*/
final class RSocketServiceMethod {

private static final String COROUTINES_FLOW_CLASS_NAME = "kotlinx.coroutines.flow.Flow";

private final Method method;

private final MethodParameter[] parameters;
Expand Down Expand Up @@ -82,6 +86,10 @@ private static MethodParameter[] initMethodParameters(Method method) {
if (count == 0) {
return new MethodParameter[0];
}
if (KotlinDetector.isSuspendingFunction(method)) {
count -= 1;
}

DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
MethodParameter[] parameters = new MethodParameter[count];
for (int i = 0; i < count; i++) {
Expand Down Expand Up @@ -129,10 +137,19 @@ private static Function<RSocketRequestValues, Object> initResponseFunction(

MethodParameter returnParam = new MethodParameter(method, -1);
Class<?> returnType = returnParam.getParameterType();
boolean isFlowReturnType = COROUTINES_FLOW_CLASS_NAME.equals(returnType.getName());
boolean isUnwrapped = KotlinDetector.isSuspendingFunction(method) && !isFlowReturnType;
if (isUnwrapped) {
returnType = Mono.class;
}
else if (isFlowReturnType) {
returnType = Flux.class;
}

ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType);

MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional());
Class<?> actualType = actualParam.getNestedParameterType();
Class<?> actualType = isUnwrapped ? actualParam.getParameterType() : actualParam.getNestedParameterType();

Function<RSocketRequestValues, Publisher<?>> responseFunction;
if (ClassUtils.isVoidType(actualType) || (reactiveAdapter != null && reactiveAdapter.isNoValue())) {
Expand All @@ -147,7 +164,8 @@ else if (reactiveAdapter == null) {
}
else {
ParameterizedTypeReference<?> payloadType =
ParameterizedTypeReference.forType(actualParam.getNestedGenericParameterType());
ParameterizedTypeReference.forType(isUnwrapped ? actualParam.getGenericParameterType() :
actualParam.getNestedGenericParameterType());

responseFunction = values -> (
reactiveAdapter.isMultiValue() ?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import org.springframework.aop.framework.ProxyFactory;
import org.springframework.aop.framework.ReflectiveMethodInvocation;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.annotation.AnnotatedElementUtils;
Expand Down Expand Up @@ -246,7 +247,9 @@ private ServiceMethodInterceptor(List<RSocketServiceMethod> methods) {
Method method = invocation.getMethod();
RSocketServiceMethod serviceMethod = this.serviceMethods.get(method);
if (serviceMethod != null) {
return serviceMethod.invoke(invocation.getArguments());
@Nullable Object[] arguments = KotlinDetector.isSuspendingFunction(method) ?
resolveCoroutinesArguments(invocation.getArguments()) : invocation.getArguments();
return serviceMethod.invoke(arguments);
}
if (method.isDefault()) {
if (invocation instanceof ReflectiveMethodInvocation reflectiveMethodInvocation) {
Expand All @@ -256,6 +259,12 @@ private ServiceMethodInterceptor(List<RSocketServiceMethod> methods) {
}
throw new IllegalStateException("Unexpected method invocation: " + method);
}

private static Object[] resolveCoroutinesArguments(@Nullable Object[] args) {
Object[] functionArgs = new Object[args.length - 1];
System.arraycopy(args, 0, functionArgs, 0, args.length - 1);
return functionArgs;
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Copyright 2002-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.messaging.rsocket.service

import io.rsocket.util.DefaultPayload
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.reactive.asFlow
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.springframework.messaging.rsocket.RSocketRequester
import org.springframework.messaging.rsocket.RSocketStrategies
import org.springframework.messaging.rsocket.TestRSocket
import org.springframework.util.MimeTypeUtils.TEXT_PLAIN
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono

/**
* Kotlin tests for [RSocketServiceMethod].
*
* @author Dmitry Sulman
*/
class RSocketServiceMethodKotlinTests {

private lateinit var rsocket: TestRSocket

private lateinit var proxyFactory: RSocketServiceProxyFactory

@BeforeEach
fun setUp() {
rsocket = TestRSocket()
val requester = RSocketRequester.wrap(rsocket, TEXT_PLAIN, TEXT_PLAIN, RSocketStrategies.create())
proxyFactory = RSocketServiceProxyFactory.builder(requester).build()
}

@Test
fun fireAndForget(): Unit = runBlocking {
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)

val requestPayload = "request"
service.fireAndForget(requestPayload)

assertThat(rsocket.savedMethodName).isEqualTo("fireAndForget")
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("ff")
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
}

@Test
fun requestResponse(): Unit = runBlocking {
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)

val requestPayload = "request"
val responsePayload = "response"
rsocket.setPayloadMonoToReturn(Mono.just(DefaultPayload.create(responsePayload)))
val response = service.requestResponse(requestPayload)

assertThat(response).isEqualTo(responsePayload)
assertThat(rsocket.savedMethodName).isEqualTo("requestResponse")
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rr")
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
}

@Test
fun requestStream(): Unit = runBlocking {
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)

val requestPayload = "request"
val responsePayload1 = "response1"
val responsePayload2 = "response2"
rsocket.setPayloadFluxToReturn(
Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2)))
val response = service.requestStream(requestPayload).toList()

assertThat(response).containsExactly(responsePayload1, responsePayload2)
assertThat(rsocket.savedMethodName).isEqualTo("requestStream")
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rs")
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
}

@Test
fun requestChannel(): Unit = runBlocking {
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)

val requestPayload1 = "request1"
val requestPayload2 = "request2"
val responsePayload1 = "response1"
val responsePayload2 = "response2"
rsocket.setPayloadFluxToReturn(
Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2)))
val response = service.requestChannel(flowOf(requestPayload1, requestPayload2)).toList()

assertThat(response).containsExactly(responsePayload1, responsePayload2)
assertThat(rsocket.savedMethodName).isEqualTo("requestChannel")

val savedPayloads = rsocket.savedPayloadFlux
?.asFlow()
?.map { it.dataUtf8 }
?.toList()
assertThat(savedPayloads).containsExactly(requestPayload1, requestPayload2)
}

private interface SuspendingFunctionsService {

@RSocketExchange("ff")
suspend fun fireAndForget(input: String)

@RSocketExchange("rr")
suspend fun requestResponse(input: String): String

@RSocketExchange("rs")
suspend fun requestStream(input: String): Flow<String>

@RSocketExchange("rc")
suspend fun requestChannel(input: Flow<String>): Flow<String>
}
}