Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions spring-cloud-gateway-integration-tests/grpc/pom.xml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<artifactId>grpc</artifactId>
Expand All @@ -22,6 +22,18 @@
<relativePath>..</relativePath> <!-- lookup parent from repository -->
</parent>

<dependencyManagement>
<dependencies>
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-bom</artifactId>
<version>2025.0.0-SNAPSHOT</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>

<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
Expand Down Expand Up @@ -113,12 +125,15 @@
<artifactId>protobuf-maven-plugin</artifactId>
<version>0.6.1</version>
<configuration>
<protocArtifact>com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier}</protocArtifact>
<protocArtifact>
com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier}</protocArtifact>
<pluginId>grpc-java</pluginId>
<pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
<pluginArtifact>
io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
</configuration>
<executions>
<execution>
<?m2e execute onConfiguration,onIncremental?>
<goals>
<goal>compile</goal>
<goal>compile-custom</goal>
Expand All @@ -128,5 +143,4 @@
</plugin>
</plugins>
</build>
</project>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ private void start() throws IOException {
Integer serverPort = environment.getProperty("local.server.port", Integer.class);
int grpcPort = serverPort + 1;
ServerCredentials creds = createServerCredentials();
server = Grpc.newServerBuilderForPort(grpcPort, creds).addService(new HelloService()).build().start();
server = Grpc.newServerBuilderForPort(grpcPort, creds)
.addService(new HelloService())
.addService(new StreamService())
.build()
.start();

log.info("Starting gRPC server in port " + grpcPort);

Expand Down Expand Up @@ -101,6 +105,38 @@ private void stop() throws InterruptedException {
log.info("gRPC server stopped");
}

static class StreamService extends StreamServiceGrpc.StreamServiceImplBase {

@Override
public void more(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
int count = 0;
while (count < 3) {
HelloResponse reply = HelloResponse.newBuilder()
.setGreeting("Hello(" + count + ") ==> " + request.getFirstName())
.build();
if ("failWithRuntimeExceptionAfterData!".equals(request.getFirstName()) && count == 2) {
StatusRuntimeException exception = Status.RESOURCE_EXHAUSTED
.withDescription("Too long firstNames?")
.asRuntimeException();
responseObserver.onError(exception);
return;
}
responseObserver.onNext(reply);
count++;
try {
Thread.sleep(200L);
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
responseObserver.onError(e);
return;
}
}
responseObserver.onCompleted();
}

}

static class HelloService extends HelloServiceGrpc.HelloServiceImplBase {

@Override
Expand All @@ -119,6 +155,14 @@ public void hello(HelloRequest request, StreamObserver<HelloResponse> responseOb
HelloResponse response = HelloResponse.newBuilder().setGreeting(greeting).build();

responseObserver.onNext(response);

if ("failWithRuntimeExceptionAfterData!".equals(request.getFirstName())) {
StatusRuntimeException exception = Status.RESOURCE_EXHAUSTED.withDescription("Too long firstNames?")
.asRuntimeException();
responseObserver.onError(exception);
return;
}

responseObserver.onCompleted();
}

Expand Down
Binary file modified spring-cloud-gateway-integration-tests/grpc/src/main/proto/hello.pb
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
syntax = "proto3";
option java_multiple_files = true;
option java_package = "org.springframework.cloud.gateway.tests.grpc";
package org.springframework.cloud.gateway.tests.grpc;

import "hello.proto";

service StreamService {
rpc more(HelloRequest) returns (stream HelloResponse);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.cloud.gateway.tests.grpc;

import java.security.cert.X509Certificate;
import java.util.Iterator;

import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManager;
Expand All @@ -30,23 +31,31 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
import org.springframework.boot.web.server.test.LocalServerPort;
import org.springframework.test.annotation.DirtiesContext;

import static io.grpc.Status.FAILED_PRECONDITION;
import static io.grpc.Status.RESOURCE_EXHAUSTED;
import static io.grpc.netty.NegotiationType.TLS;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment;

/**
* @author Alberto C. Ríos
*/
@SpringBootTest(classes = org.springframework.cloud.gateway.tests.grpc.GRPCApplication.class,
webEnvironment = WebEnvironment.RANDOM_PORT)
@DirtiesContext
public class GRPCApplicationTests {

@LocalServerPort
private int gatewayPort;

public static void main(String[] args) {
SpringApplication.run(GRPCApplication.class, args);
}

@BeforeEach
void setUp() {
int grpcServerPort = gatewayPort + 1;
Expand All @@ -64,6 +73,18 @@ public void gRPCUnaryCallShouldReturnResponse() throws SSLException {
Assertions.assertThat(response.getGreeting()).isEqualTo("Hello, Sir FromClient");
}

@Test
public void gRPCStreamingCallShouldReturnResponse() throws SSLException {
ManagedChannel channel = createSecuredChannel(gatewayPort);

final Iterator<HelloResponse> response = StreamServiceGrpc.newBlockingStub(channel)
.more(HelloRequest.newBuilder().setFirstName("Sir").setLastName("FromClient").build());

Assertions.assertThat(response.next().getGreeting()).isEqualTo("Hello(0) ==> Sir");
Assertions.assertThat(response.next().getGreeting()).isEqualTo("Hello(1) ==> Sir");
Assertions.assertThat(response.next().getGreeting()).isEqualTo("Hello(2) ==> Sir");
}

private ManagedChannel createSecuredChannel(int port) throws SSLException {
TrustManager[] trustAllCerts = createTrustAllTrustManager();

Expand All @@ -88,6 +109,48 @@ public void gRPCUnaryCallShouldHandleRuntimeException() throws SSLException {
}
}

@Test
public void gRPCUnaryCallShouldHandleRuntimeExceptionAfterData() throws SSLException {
ManagedChannel channel = createSecuredChannel(gatewayPort);
boolean thrown = false;
try {
HelloServiceGrpc.newBlockingStub(channel)
.hello(HelloRequest.newBuilder().setFirstName("failWithRuntimeExceptionAfterData!").build())
.getGreeting();
}
catch (StatusRuntimeException e) {
thrown = true;
Assertions.assertThat(e.getStatus().getCode()).isEqualTo(RESOURCE_EXHAUSTED.getCode());
Assertions.assertThat(e.getStatus().getDescription()).isEqualTo("Too long firstNames?");
}
Assertions.assertThat(thrown).withFailMessage("Expected exception not thrown!").isTrue();
}

@Test
public void gRPCStreamingCallShouldHandleRuntimeExceptionAfterData() throws SSLException {
ManagedChannel channel = createSecuredChannel(gatewayPort);
boolean thrown = false;
final Iterator<HelloResponse> response = StreamServiceGrpc.newBlockingStub(channel)
.more(HelloRequest.newBuilder()
.setFirstName("failWithRuntimeExceptionAfterData!")
.setLastName("FromClient")
.build());
Assertions.assertThat(response.next().getGreeting())
.isEqualTo("Hello(0) ==> failWithRuntimeExceptionAfterData!");
Assertions.assertThat(response.next().getGreeting())
.isEqualTo("Hello(1) ==> failWithRuntimeExceptionAfterData!");
try {
Assertions.assertThat(response.next().getGreeting())
.isEqualTo("Hello(2) ==> failWithRuntimeExceptionAfterData!");
}
catch (StatusRuntimeException e) {
thrown = true;
Assertions.assertThat(e.getStatus().getCode()).isEqualTo(RESOURCE_EXHAUSTED.getCode());
Assertions.assertThat(e.getStatus().getDescription()).isEqualTo("Too long firstNames?");
}
Assertions.assertThat(thrown).withFailMessage("Expected exception not thrown!").isTrue();
}

private TrustManager[] createTrustAllTrustManager() {
return new TrustManager[] { new X509TrustManager() {
public X509Certificate[] getAcceptedIssuers() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@
import org.apache.hc.core5.ssl.TrustStrategy;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.web.server.test.LocalServerPort;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.web.client.RestTemplate;

import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
Expand All @@ -50,8 +50,8 @@
* @author Alberto C. Ríos
* @author Abel Salgado Romero
*/
@Disabled
@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT)
@DirtiesContext
public class JsonToGrpcApplicationTests {

@LocalServerPort
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ server:
management:
endpoint:
health:
show-details: when_authorized
show-details: when-authorized
gateway:
enabled: true
endpoints:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -828,8 +828,9 @@ public NettyRoutingFilter routingFilter(HttpClient httpClient,

@Bean
@ConditionalOnEnabledGlobalFilter(NettyRoutingFilter.class)
public NettyWriteResponseFilter nettyWriteResponseFilter(GatewayProperties properties) {
return new NettyWriteResponseFilter(properties.getStreamingMediaTypes());
public NettyWriteResponseFilter nettyWriteResponseFilter(GatewayProperties properties,
ObjectProvider<List<HttpHeadersFilter>> headersFilters) {
return new NettyWriteResponseFilter(properties.getStreamingMediaTypes(), headersFilters);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
import reactor.core.publisher.Mono;
import reactor.core.publisher.SignalType;
import reactor.netty.Connection;
import reactor.netty.http.client.HttpClientResponse;

import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter;
import org.springframework.cloud.gateway.filter.headers.TrailerHeadersFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
Expand All @@ -36,6 +40,7 @@
import org.springframework.lang.Nullable;
import org.springframework.web.server.ServerWebExchange;

import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CLIENT_RESPONSE_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CLIENT_RESPONSE_CONN_ATTR;

/**
Expand All @@ -52,8 +57,22 @@ public class NettyWriteResponseFilter implements GlobalFilter, Ordered {

private final List<MediaType> streamingMediaTypes;

public NettyWriteResponseFilter(List<MediaType> streamingMediaTypes) {
private final ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider;

// do not use this headersFilters directly, use getHeadersFilters() instead.
private volatile List<HttpHeadersFilter> headersFilters;

public NettyWriteResponseFilter(List<MediaType> streamingMediaTypes,
ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider) {
this.streamingMediaTypes = streamingMediaTypes;
this.headersFiltersProvider = headersFiltersProvider;
}

public List<HttpHeadersFilter> getHeadersFilters() {
if (headersFilters == null) {
headersFilters = headersFiltersProvider == null ? List.of() : headersFiltersProvider.getIfAvailable();
}
return headersFilters;
}

@Override
Expand Down Expand Up @@ -96,9 +115,12 @@ public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
log.trace("invalid media type", e);
}
}
return (isStreamingMediaType(contentType)

HttpClientResponse httpClientResponse = exchange.getAttribute(CLIENT_RESPONSE_ATTR);
Mono<Void> write = (isStreamingMediaType(contentType)
? response.writeAndFlushWith(body.map(Flux::just))
: response.writeWith(body));
return write.then(TrailerHeadersFilter.filter(getHeadersFilters(), exchange, httpClientResponse)).then();
}))
.doFinally(signalType -> {
if (signalType == SignalType.CANCEL || signalType == SignalType.ON_ERROR) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,20 @@ public class GRPCResponseHeadersFilter implements HttpHeadersFilter, Ordered {
@Override
public HttpHeaders filter(HttpHeaders headers, ServerWebExchange exchange) {
ServerHttpResponse response = exchange.getResponse();
HttpHeaders responseHeaders = response.getHeaders();
if (isGRPC(exchange)) {
String trailerHeaderValue = GRPC_STATUS_HEADER + "," + GRPC_MESSAGE_HEADER;
String originalTrailerHeaderValue = responseHeaders.getFirst(HttpHeaders.TRAILER);
if (originalTrailerHeaderValue != null) {
trailerHeaderValue += "," + originalTrailerHeaderValue;
}
responseHeaders.set(HttpHeaders.TRAILER, trailerHeaderValue);

while (response instanceof ServerHttpResponseDecorator) {
response = ((ServerHttpResponseDecorator) response).getDelegate();
}
if (response instanceof AbstractServerHttpResponse) {
String grpcStatus = getGrpcStatus(headers);
String grpcMessage = getGrpcMessage(headers);
((HttpServerResponse) ((AbstractServerHttpResponse) response).getNativeResponse()).trailerHeaders(h -> {
h.set(GRPC_STATUS_HEADER, grpcStatus);
h.set(GRPC_MESSAGE_HEADER, grpcMessage);
});
String grpcStatus = getGrpcStatus(headers);
String grpcMessage = getGrpcMessage(headers);
if (grpcStatus != null) {
while (response instanceof ServerHttpResponseDecorator) {
response = ((ServerHttpResponseDecorator) response).getDelegate();
}
if (response instanceof AbstractServerHttpResponse) {
((HttpServerResponse) ((AbstractServerHttpResponse) response).getNativeResponse())
.trailerHeaders(h -> {
h.set(GRPC_STATUS_HEADER, grpcStatus);
h.set(GRPC_MESSAGE_HEADER, grpcMessage);
});
}
}

}
Expand All @@ -70,7 +65,7 @@ private boolean isGRPC(ServerWebExchange exchange) {

private String getGrpcStatus(HttpHeaders headers) {
final String grpcStatusValue = headers.getFirst(GRPC_STATUS_HEADER);
return StringUtils.hasText(grpcStatusValue) ? grpcStatusValue : "0";
return StringUtils.hasText(grpcStatusValue) ? grpcStatusValue : null;
}

private String getGrpcMessage(HttpHeaders headers) {
Expand Down
Loading
Loading