Skip to content

Commit 1a4844f

Browse files
authored
Merge pull request #529 from rsocket/feature/keep-alives-reimplementation
Reimplementation of keep-alives according to spec
2 parents d83ffdd + eba3202 commit 1a4844f

File tree

11 files changed

+430
-110
lines changed

11 files changed

+430
-110
lines changed

build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ subprojects {
5050
dependencySet(group: 'org.junit.jupiter', version: '5.1.0') {
5151
entry 'junit-jupiter-api'
5252
entry 'junit-jupiter-engine'
53+
entry 'junit-jupiter-params'
5354
}
5455

5556
// TODO: Remove after JUnit5 migration

rsocket-core/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ dependencies {
3636
testImplementation 'io.projectreactor:reactor-test'
3737
testImplementation 'org.assertj:assertj-core'
3838
testImplementation 'org.junit.jupiter:junit-jupiter-api'
39+
testImplementation 'org.junit.jupiter:junit-jupiter-params'
3940
testImplementation 'org.mockito:mockito-core'
4041

4142
testRuntimeOnly 'ch.qos.logback:logback-classic'

rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ public static ConnectionSetupPayload create(final Frame setupFrame) {
3434
return new DefaultConnectionSetupPayload(setupFrame);
3535
}
3636

37+
public abstract int keepAliveInterval();
38+
39+
public abstract int keepAliveMaxLifetime();
40+
3741
public abstract String metadataMimeType();
3842

3943
public abstract String dataMimeType();
@@ -73,6 +77,16 @@ public DefaultConnectionSetupPayload(final Frame setupFrame) {
7377
this.setupFrame = setupFrame;
7478
}
7579

80+
@Override
81+
public int keepAliveInterval() {
82+
return SetupFrameFlyweight.keepaliveInterval(setupFrame.content());
83+
}
84+
85+
@Override
86+
public int keepAliveMaxLifetime() {
87+
return SetupFrameFlyweight.maxLifetime(setupFrame.content());
88+
}
89+
7690
@Override
7791
public String metadataMimeType() {
7892
return Setup.metadataMimeType(setupFrame);
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package io.rsocket;
2+
3+
import io.netty.buffer.Unpooled;
4+
import java.time.Duration;
5+
import reactor.core.Disposable;
6+
import reactor.core.publisher.Flux;
7+
import reactor.core.publisher.Mono;
8+
import reactor.core.publisher.MonoProcessor;
9+
import reactor.core.publisher.UnicastProcessor;
10+
11+
abstract class KeepAliveHandler {
12+
private final KeepAlive keepAlive;
13+
private final UnicastProcessor<Frame> sent = UnicastProcessor.create();
14+
private final MonoProcessor<KeepAlive> timeout = MonoProcessor.create();
15+
private final Flux<Long> interval;
16+
private Disposable intervalDisposable;
17+
private volatile long lastReceivedMillis;
18+
19+
static KeepAliveHandler ofServer(KeepAlive keepAlive) {
20+
return new KeepAliveHandler.Server(keepAlive);
21+
}
22+
23+
static KeepAliveHandler ofClient(KeepAlive keepAlive) {
24+
return new KeepAliveHandler.Client(keepAlive);
25+
}
26+
27+
private KeepAliveHandler(KeepAlive keepAlive) {
28+
this.keepAlive = keepAlive;
29+
this.interval = Flux.interval(Duration.ofMillis(keepAlive.getTickPeriod()));
30+
}
31+
32+
public void start() {
33+
this.lastReceivedMillis = System.currentTimeMillis();
34+
intervalDisposable = interval.subscribe(v -> onIntervalTick());
35+
}
36+
37+
public void stop() {
38+
sent.onComplete();
39+
timeout.onComplete();
40+
if (intervalDisposable != null) {
41+
intervalDisposable.dispose();
42+
}
43+
}
44+
45+
public void receive(Frame keepAliveFrame) {
46+
this.lastReceivedMillis = System.currentTimeMillis();
47+
if (Frame.Keepalive.hasRespondFlag(keepAliveFrame)) {
48+
doSend(Frame.Keepalive.from(Unpooled.wrappedBuffer(keepAliveFrame.getData()), false));
49+
}
50+
}
51+
52+
public Flux<Frame> send() {
53+
return sent;
54+
}
55+
56+
public Mono<KeepAlive> timeout() {
57+
return timeout;
58+
}
59+
60+
abstract void onIntervalTick();
61+
62+
void doSend(Frame frame) {
63+
sent.onNext(frame);
64+
}
65+
66+
void doCheckTimeout() {
67+
long now = System.currentTimeMillis();
68+
if (now - lastReceivedMillis >= keepAlive.getTimeoutMillis()) {
69+
timeout.onNext(keepAlive);
70+
}
71+
}
72+
73+
private static class Server extends KeepAliveHandler {
74+
75+
Server(KeepAlive keepAlive) {
76+
super(keepAlive);
77+
}
78+
79+
@Override
80+
void onIntervalTick() {
81+
doCheckTimeout();
82+
}
83+
}
84+
85+
private static final class Client extends KeepAliveHandler {
86+
87+
Client(KeepAlive keepAlive) {
88+
super(keepAlive);
89+
}
90+
91+
@Override
92+
void onIntervalTick() {
93+
doCheckTimeout();
94+
doSend(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true));
95+
}
96+
}
97+
98+
static final class KeepAlive {
99+
private final long tickPeriod;
100+
private final long timeoutMillis;
101+
102+
KeepAlive(Duration tickPeriod, Duration timeoutMillis, int maxTicks) {
103+
this.tickPeriod = tickPeriod.toMillis();
104+
this.timeoutMillis = timeoutMillis.toMillis() + maxTicks * tickPeriod.toMillis();
105+
}
106+
107+
KeepAlive(long tickPeriod, long timeoutMillis) {
108+
this.tickPeriod = tickPeriod;
109+
this.timeoutMillis = timeoutMillis;
110+
}
111+
112+
public long getTickPeriod() {
113+
return tickPeriod;
114+
}
115+
116+
public long getTimeoutMillis() {
117+
return timeoutMillis;
118+
}
119+
}
120+
}

rsocket-core/src/main/java/io/rsocket/RSocketClient.java

Lines changed: 26 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,18 @@
1616

1717
package io.rsocket;
1818

19-
import io.netty.buffer.Unpooled;
2019
import io.rsocket.exceptions.ConnectionErrorException;
2120
import io.rsocket.exceptions.Exceptions;
2221
import io.rsocket.framing.FrameType;
2322
import io.rsocket.internal.LimitableRequestPublisher;
2423
import io.rsocket.internal.UnboundedProcessor;
2524
import java.time.Duration;
2625
import java.util.concurrent.atomic.AtomicBoolean;
27-
import java.util.concurrent.atomic.AtomicInteger;
2826
import java.util.function.Consumer;
2927
import java.util.function.Function;
30-
import javax.annotation.Nullable;
3128
import org.jctools.maps.NonBlockingHashMapLong;
3229
import org.reactivestreams.Publisher;
3330
import org.reactivestreams.Subscriber;
34-
import reactor.core.Disposable;
3531
import reactor.core.publisher.*;
3632

3733
/** Client Side of a RSocket socket. Sends {@link Frame}s to a {@link RSocketServer} */
@@ -44,13 +40,10 @@ class RSocketClient implements RSocket {
4440
private final MonoProcessor<Void> started;
4541
private final NonBlockingHashMapLong<LimitableRequestPublisher> senders;
4642
private final NonBlockingHashMapLong<UnicastProcessor<Payload>> receivers;
47-
private final AtomicInteger missedAckCounter;
48-
4943
private final UnboundedProcessor<Frame> sendProcessor;
44+
private KeepAliveHandler keepAliveHandler;
5045

51-
private @Nullable Disposable keepAliveSendSub;
52-
private volatile long timeLastTickSentMs;
53-
46+
/*server requester*/
5447
RSocketClient(
5548
DuplexConnection connection,
5649
Function<Frame, ? extends Payload> frameDecoder,
@@ -59,7 +52,7 @@ class RSocketClient implements RSocket {
5952
this(
6053
connection, frameDecoder, errorConsumer, streamIdSupplier, Duration.ZERO, Duration.ZERO, 0);
6154
}
62-
55+
/*client requester*/
6356
RSocketClient(
6457
DuplexConnection connection,
6558
Function<Frame, ? extends Payload> frameDecoder,
@@ -75,24 +68,29 @@ class RSocketClient implements RSocket {
7568
this.started = MonoProcessor.create();
7669
this.senders = new NonBlockingHashMapLong<>(256);
7770
this.receivers = new NonBlockingHashMapLong<>(256);
78-
this.missedAckCounter = new AtomicInteger();
7971

8072
// DO NOT Change the order here. The Send processor must be subscribed to before receiving
8173
this.sendProcessor = new UnboundedProcessor<>();
8274

8375
if (!Duration.ZERO.equals(tickPeriod)) {
84-
long ackTimeoutMs = ackTimeout.toMillis();
85-
86-
this.keepAliveSendSub =
87-
started
88-
.thenMany(Flux.interval(tickPeriod))
89-
.doOnSubscribe(s -> timeLastTickSentMs = System.currentTimeMillis())
90-
.subscribe(
91-
i -> sendKeepAlive(ackTimeoutMs, missedAcks),
92-
t -> {
93-
errorConsumer.accept(t);
94-
connection.dispose();
95-
});
76+
this.keepAliveHandler =
77+
KeepAliveHandler.ofClient(
78+
new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout, missedAcks));
79+
80+
started.doOnTerminate(() -> keepAliveHandler.start()).subscribe();
81+
82+
keepAliveHandler
83+
.timeout()
84+
.subscribe(
85+
keepAlive -> {
86+
String message =
87+
String.format("No keep-alive acks for %d ms", keepAlive.getTimeoutMillis());
88+
errorConsumer.accept(new ConnectionErrorException(message));
89+
connection.dispose();
90+
});
91+
keepAliveHandler.send().subscribe(sendProcessor::onNext);
92+
} else {
93+
keepAliveHandler = null;
9694
}
9795

9896
connection.onClose().doFinally(signalType -> cleanup()).subscribe(null, errorConsumer);
@@ -140,22 +138,6 @@ private void handleSendProcessorCancel(SignalType t) {
140138
}
141139
}
142140

143-
private void sendKeepAlive(long ackTimeoutMs, int missedAcks) {
144-
long now = System.currentTimeMillis();
145-
if (now - timeLastTickSentMs > ackTimeoutMs) {
146-
int count = missedAckCounter.incrementAndGet();
147-
if (count >= missedAcks) {
148-
String message =
149-
String.format(
150-
"Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms",
151-
count, missedAcks, ackTimeoutMs);
152-
throw new ConnectionErrorException(message);
153-
}
154-
}
155-
156-
sendProcessor.onNext(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true));
157-
}
158-
159141
@Override
160142
public Mono<Void> fireAndForget(Payload payload) {
161143
Mono<Void> defer =
@@ -380,17 +362,16 @@ private boolean contains(int streamId) {
380362
}
381363

382364
protected void cleanup() {
365+
if (keepAliveHandler != null) {
366+
keepAliveHandler.stop();
367+
}
383368
try {
384369
for (UnicastProcessor<Payload> subscriber : receivers.values()) {
385370
cleanUpSubscriber(subscriber);
386371
}
387372
for (LimitableRequestPublisher p : senders.values()) {
388373
cleanUpLimitableRequestPublisher(p);
389374
}
390-
391-
if (null != keepAliveSendSub) {
392-
keepAliveSendSub.dispose();
393-
}
394375
} finally {
395376
senders.clear();
396377
receivers.clear();
@@ -437,8 +418,8 @@ private void handleStreamZero(FrameType type, Frame frame) {
437418
break;
438419
}
439420
case KEEPALIVE:
440-
if (!Frame.Keepalive.hasRespondFlag(frame)) {
441-
timeLastTickSentMs = System.currentTimeMillis();
421+
if (keepAliveHandler != null) {
422+
keepAliveHandler.receive(frame);
442423
}
443424
break;
444425
default:

rsocket-core/src/main/java/io/rsocket/RSocketFactory.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public static class ClientRSocketFactory implements ClientTransportAcceptor {
8787
private Payload setupPayload = EmptyPayload.INSTANCE;
8888
private Function<Frame, ? extends Payload> frameDecoder = DefaultPayload::create;
8989

90-
private Duration tickPeriod = Duration.ZERO;
90+
private Duration tickPeriod = Duration.ofSeconds(20);
9191
private Duration ackTimeout = Duration.ofSeconds(30);
9292
private int missedAcks = 3;
9393

@@ -109,8 +109,13 @@ public ClientRSocketFactory addServerPlugin(RSocketInterceptor interceptor) {
109109
return this;
110110
}
111111

112+
/**
113+
* Deprecated as Keep-Alive is not optional according to spec
114+
*
115+
* @return this ClientRSocketFactory
116+
*/
117+
@Deprecated
112118
public ClientRSocketFactory keepAlive() {
113-
tickPeriod = Duration.ofSeconds(20);
114119
return this;
115120
}
116121

@@ -205,8 +210,8 @@ public Mono<RSocket> start() {
205210
Frame setupFrame =
206211
Frame.Setup.from(
207212
flags,
208-
(int) ackTimeout.toMillis(),
209-
(int) ackTimeout.toMillis() * missedAcks,
213+
(int) tickPeriod.toMillis(),
214+
(int) (ackTimeout.toMillis() + tickPeriod.toMillis() * missedAcks),
210215
metadataMimeType,
211216
dataMimeType,
212217
setupPayload);
@@ -339,6 +344,8 @@ private Mono<Void> processSetupFrame(
339344
}
340345

341346
ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(setupFrame);
347+
int keepAliveInterval = setupPayload.keepAliveInterval();
348+
int keepAliveMaxLifetime = setupPayload.keepAliveMaxLifetime();
342349

343350
RSocketClient rSocketClient =
344351
new RSocketClient(
@@ -361,7 +368,9 @@ private Mono<Void> processSetupFrame(
361368
multiplexer.asClientConnection(),
362369
wrappedRSocketServer,
363370
frameDecoder,
364-
errorConsumer);
371+
errorConsumer,
372+
keepAliveInterval,
373+
keepAliveMaxLifetime);
365374
})
366375
.doFinally(signalType -> setupPayload.release())
367376
.then();

0 commit comments

Comments
 (0)