Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ subprojects {
apply plugin: 'io.spring.dependency-management'
apply plugin: 'com.github.sherter.google-java-format'

ext['reactor-bom.version'] = 'Dysprosium-M3'
ext['reactor-bom.version'] = 'Dysprosium-BUILD-SNAPSHOT'
ext['logback.version'] = '1.2.3'
ext['findbugs.version'] = '3.0.2'
ext['netty.version'] = '4.1.37.Final'
Expand Down Expand Up @@ -97,9 +97,9 @@ subprojects {
mavenCentral()
maven { url 'http://repo.spring.io/milestone' } // temporary for Reactor Dysprosium

if (version.endsWith('BUILD-SNAPSHOT') || project.hasProperty('platformVersion')) {
// if (version.endsWith('BUILD-SNAPSHOT') || project.hasProperty('platformVersion')) {
maven { url 'http://repo.spring.io/libs-snapshot' }
}
// }
}

if (project.name != 'rsocket-bom') {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,71 +16,69 @@

package io.rsocket.examples.transport.ws;

import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.rsocket.AbstractRSocket;
import io.rsocket.ConnectionSetupPayload;
import io.rsocket.DuplexConnection;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.RSocketFactory;
import io.rsocket.SocketAcceptor;
import io.rsocket.frame.decoder.PayloadDecoder;
import io.rsocket.transport.ServerTransport;
import io.rsocket.transport.netty.WebsocketDuplexConnection;
import io.rsocket.transport.netty.client.WebsocketClientTransport;
import io.rsocket.transport.netty.server.CloseableChannel;
import io.rsocket.transport.netty.server.WebsocketServerTransport;
import io.rsocket.util.ByteBufPayload;
import java.time.Duration;
import java.util.HashMap;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import reactor.core.scheduler.Schedulers;
import reactor.netty.Connection;
import reactor.netty.DisposableServer;
import reactor.netty.http.server.HttpServer;

public class WebSocketHeadersSample {
static final Payload payload1 = ByteBufPayload.create("Hello ");

public static void main(String[] args) {

ServerTransport.ConnectionAcceptor acceptor =
CloseableChannel disposableServer =
RSocketFactory.receive()
.frameDecoder(PayloadDecoder.ZERO_COPY)
.acceptor(new SocketAcceptorImpl())
.toConnectionAcceptor();

DisposableServer disposableServer =
HttpServer.create()
.host("localhost")
.port(0)
.route(
routes ->
routes.ws(
"/",
(in, out) -> {
if (in.headers().containsValue("Authorization", "test", true)) {
DuplexConnection connection =
new WebsocketDuplexConnection((Connection) in);
return acceptor.apply(connection).then(out.neverComplete());
}

return out.sendClose(
HttpResponseStatus.UNAUTHORIZED.code(),
HttpResponseStatus.UNAUTHORIZED.reasonPhrase());
}))
.bindNow();
.transport(
WebsocketServerTransport.builder()
.filteringInbound(
headers -> headers.containsValue("Authorization", "test", true))
.closingWithStatus(headers -> new WebSocketCloseStatus(4404, "Unauthorized"))
.build(HttpServer.create().host("localhost").port(8080))
// Same could be done with routing transport
// WebsocketRouteTransport
// .builder()
// .filteringInbound(headers ->
// headers.containsValue("Authorization", "test", true))
// .closingWithStatus(headers -> new WebSocketCloseStatus(4404,
// "Unauthorized"))
// .observingOn("/")
// .build(HttpServer.create().host("localhost").port(8080))
)
.start()
.block();

WebsocketClientTransport clientTransport =
WebsocketClientTransport.create(disposableServer.host(), disposableServer.port());
WebsocketClientTransport.create(disposableServer.address());

MonoProcessor<WebSocketCloseStatus> statusMonoProcessor = MonoProcessor.create();
clientTransport.setTransportHeaders(
() -> {
HashMap<String, String> map = new HashMap<>();
map.put("Authorization", "test");
map.put("Authorization", "1");
return map;
});

clientTransport.setCloseStatusConsumer(
webSocketCloseStatusMono -> webSocketCloseStatusMono.log().subscribe(statusMonoProcessor));

RSocket socket =
RSocketFactory.connect()
.keepAliveAckTimeout(Duration.ofMinutes(10))
Expand All @@ -89,17 +87,24 @@ public static void main(String[] args) {
.start()
.block();

Flux.range(0, 100)
.concatMap(i -> socket.fireAndForget(payload1.retain()))
// .doOnNext(p -> {
//// System.out.println(p.getDataUtf8());
// p.release();
// })
.blockLast();
try {
Flux.range(0, 100).concatMap(i -> socket.fireAndForget(payload1.retain())).blockLast();

} catch (Exception e) {
System.out.println("Observed WebSocket Close Status " + statusMonoProcessor.peek());
}

socket.dispose();

WebsocketClientTransport clientTransport2 =
WebsocketClientTransport.create(disposableServer.host(), disposableServer.port());
WebsocketClientTransport.create(disposableServer.address());

clientTransport2.setTransportHeaders(
() -> {
HashMap<String, String> map = new HashMap<>();
map.put("Authorization", "test");
return map;
});

RSocket rSocket =
RSocketFactory.connect()
Expand All @@ -109,7 +114,7 @@ public static void main(String[] args) {
.start()
.block();

// expect error here because of closed channel
// expect normal execution here
rSocket.requestResponse(payload1).block();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static io.rsocket.transport.netty.UriUtils.isSecure;

import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.rsocket.DuplexConnection;
import io.rsocket.fragmentation.FragmentationDuplexConnection;
import io.rsocket.transport.ClientTransport;
Expand All @@ -32,9 +33,11 @@
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Supplier;
import reactor.core.publisher.Mono;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.websocket.WebsocketInbound;
import reactor.netty.tcp.TcpClient;

/**
Expand All @@ -48,7 +51,9 @@ public final class WebsocketClientTransport implements ClientTransport, Transpor

private final HttpClient client;

private String path;
private final String path;

private Consumer<Mono<WebSocketCloseStatus>> closeStatusConsumer;

private Supplier<Map<String, String>> transportHeaders = Collections::emptyMap;

Expand Down Expand Up @@ -161,6 +166,11 @@ public Mono<DuplexConnection> connect(int mtu) {
.connect()
.map(
c -> {
Consumer<Mono<WebSocketCloseStatus>> closeStatusConsumer =
this.closeStatusConsumer;
if (closeStatusConsumer != null) {
closeStatusConsumer.accept(((WebsocketInbound) c).receiveCloseStatus());
}
DuplexConnection connection = new WebsocketDuplexConnection(c);
if (mtu > 0) {
connection =
Expand All @@ -176,4 +186,8 @@ public void setTransportHeaders(Supplier<Map<String, String>> transportHeaders)
this.transportHeaders =
Objects.requireNonNull(transportHeaders, "transportHeaders must not be null");
}

public void setCloseStatusConsumer(Consumer<Mono<WebSocketCloseStatus>> closeStatusConsumer) {
this.closeStatusConsumer = closeStatusConsumer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

package io.rsocket.transport.netty.server;

import static io.netty.handler.codec.http.websocketx.WebSocketCloseStatus.NORMAL_CLOSURE;
import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK;

import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.rsocket.Closeable;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.rsocket.DuplexConnection;
import io.rsocket.fragmentation.FragmentationDuplexConnection;
import io.rsocket.transport.ServerTransport;
Expand All @@ -32,8 +34,11 @@
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.annotation.Nullable;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
import reactor.netty.Connection;
Expand All @@ -46,14 +51,23 @@
* An implementation of {@link ServerTransport} that connects via Websocket and listens on specified
* routes.
*/
public final class WebsocketRouteTransport implements ServerTransport<Closeable> {
public final class WebsocketRouteTransport implements ServerTransport<CloseableChannel> {

private static final Function<HttpHeaders, WebSocketCloseStatus> DEFAULT_STATUS_SUPPLIER =
__ -> NORMAL_CLOSURE;

private static final Consumer<HttpServerRoutes> NO_OPS_ROUTES_BUILDER = (r) -> {};

private final UriPathTemplate template;

private final Consumer<? super HttpServerRoutes> routesBuilder;

private final HttpServer server;

private final Function<HttpHeaders, WebSocketCloseStatus> webSocketCloseStatusSupplier;

private final @Nullable Predicate<HttpHeaders> headersPredicate;

/**
* Creates a new instance
*
Expand All @@ -64,15 +78,43 @@ public final class WebsocketRouteTransport implements ServerTransport<Closeable>
public WebsocketRouteTransport(
HttpServer server, Consumer<? super HttpServerRoutes> routesBuilder, String path) {

this(server, routesBuilder, path, null, DEFAULT_STATUS_SUPPLIER);
}

public WebsocketRouteTransport(
HttpServer server,
Consumer<? super HttpServerRoutes> routesBuilder,
String path,
@Nullable Predicate<HttpHeaders> headersPredicate,
Function<HttpHeaders, WebSocketCloseStatus> webSocketCloseStatusSupplier) {

this.server = Objects.requireNonNull(server, "server must not be null");
this.routesBuilder = Objects.requireNonNull(routesBuilder, "routesBuilder must not be null");
this.template = new UriPathTemplate(Objects.requireNonNull(path, "path must not be null"));
this.headersPredicate = headersPredicate;
this.webSocketCloseStatusSupplier =
Objects.requireNonNull(webSocketCloseStatusSupplier, "status supplier must not be null");
}

@Override
public Mono<Closeable> start(ConnectionAcceptor acceptor, int mtu) {
public Mono<CloseableChannel> start(ConnectionAcceptor acceptor, int mtu) {
Objects.requireNonNull(acceptor, "acceptor must not be null");

if (headersPredicate != null) {
return server
.route(
routes -> {
routesBuilder.accept(routes);
routes.ws(
hsr -> hsr.method().equals(HttpMethod.GET) && template.matches(hsr.uri()),
newHandler(acceptor, mtu, headersPredicate, webSocketCloseStatusSupplier),
null,
FRAME_LENGTH_MASK);
})
.bind()
.map(CloseableChannel::new);
}

return server
.route(
routes -> {
Expand Down Expand Up @@ -120,6 +162,36 @@ public static BiFunction<WebsocketInbound, WebsocketOutbound, Publisher<Void>> n
};
}

/**
* Creates a new Websocket handler
*
* @param acceptor the {@link ConnectionAcceptor} to use with the handler
* @param mtu the fragment size
* @return a new Websocket handler
* @throws NullPointerException if {@code acceptor} is {@code null}
*/
public static BiFunction<WebsocketInbound, WebsocketOutbound, Publisher<Void>> newHandler(
ConnectionAcceptor acceptor,
int mtu,
Predicate<HttpHeaders> headersPredicate,
Function<HttpHeaders, WebSocketCloseStatus> webSocketCloseStatusSupplier) {
return (in, out) -> {
HttpHeaders headers = in.headers();
if (!headersPredicate.test(headers)) {
final WebSocketCloseStatus status = webSocketCloseStatusSupplier.apply(headers);
return out.sendClose(status.code(), status.reasonText());
}

DuplexConnection connection = new WebsocketDuplexConnection((Connection) in);
if (mtu > 0) {
connection =
new FragmentationDuplexConnection(
connection, ByteBufAllocator.DEFAULT, mtu, false, "server");
}
return acceptor.apply(connection).then(out.neverComplete());
};
}

static final class UriPathTemplate {

private static final Pattern FULL_SPLAT_PATTERN = Pattern.compile("[\\*][\\*]");
Expand Down Expand Up @@ -236,4 +308,47 @@ private Matcher matcher(String uri) {
return m;
}
}

public static Builder builder() {
return new Builder();
}

public static class Builder {

private Predicate<HttpHeaders> headersPredicate;
private Function<HttpHeaders, WebSocketCloseStatus> webSocketCloseStatusSupplier =
DEFAULT_STATUS_SUPPLIER;
private String path = "/";
private Consumer<? super HttpServerRoutes> routesBuilder = NO_OPS_ROUTES_BUILDER;

public Builder filteringInbound(Predicate<HttpHeaders> headersPredicate) {
Objects.requireNonNull(headersPredicate, "Header predicate must not be null");
this.headersPredicate = headersPredicate;
return this;
}

public Builder closingWithStatus(
Function<HttpHeaders, WebSocketCloseStatus> webSocketCloseStatusSupplier) {
this.webSocketCloseStatusSupplier =
Objects.requireNonNull(
webSocketCloseStatusSupplier, "WebSocketCloseStatusSupplier must not be null");
return this;
}

public Builder observingOn(String path) {
Objects.requireNonNull(path, "path must not be null");
this.path = path;
return this;
}

public Builder routingWith(Consumer<? super HttpServerRoutes> routesBuilder) {
this.routesBuilder = Objects.requireNonNull(routesBuilder, "routesBuilder must not be null");
return this;
}

public WebsocketRouteTransport build(HttpServer server) {
return new WebsocketRouteTransport(
server, routesBuilder, path, headersPredicate, webSocketCloseStatusSupplier);
}
}
}
Loading