Skip to content

Commit 5b87f3b

Browse files
committed
Implement token refresh support class
1 parent 780e64b commit 5b87f3b

File tree

6 files changed

+217
-43
lines changed

6 files changed

+217
-43
lines changed

src/main/java/com/rabbitmq/client/amqp/impl/TokenCredentials.java

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,13 @@ private void unlock() {
5656
}
5757

5858
private boolean expiresSoon(Token t) {
59+
// TODO use strategy to tell if the token expires soon
5960
return t.expirationTime() < System.currentTimeMillis() - 20_000;
6061
}
6162

6263
private Duration delayBeforeTokenRenewal(Token token) {
6364
long expiresIn = token.expirationTime() - System.currentTimeMillis();
65+
// TODO use strategy to decide when to renew token
6466
long delay = (long) (expiresIn * 0.8);
6567
return Duration.ofMillis(delay);
6668
}
@@ -77,31 +79,43 @@ public Registration register(RefreshCallback refreshCallback) {
7779
return registration;
7880
}
7981

80-
private void refresh() {
82+
private void refreshRegistrations(Token t) {
8183
this.scheduledExecutorService.execute(
8284
() -> {
8385
for (RegistrationImpl registration : this.registrations.values()) {
84-
if (!registration.isClosed() && !this.token.equals(registration.registrationToken)) {
85-
// the registration does not have the new token yet
86-
registration.refreshCallback.refresh("", this.token.value());
87-
registration.registrationToken = this.token;
86+
if (t.equals(this.token)) {
87+
if (!registration.isClosed() && !registration.hasSameToken(t)) {
88+
// the registration does not have the new token yet
89+
registration.refreshCallback.refresh("", this.token.value());
90+
registration.registrationToken = this.token;
91+
}
8892
}
8993
}
9094
});
9195
}
9296

9397
private void token(Token t) {
94-
if (!t.equals(this.token)) {
95-
this.token = t;
96-
if (this.schedulingRenewal.compareAndSet(false, true)) {
97-
if (this.renewalTask != null) {
98-
this.renewalTask.cancel(false);
99-
}
100-
Duration delay = delayBeforeTokenRenewal(t);
101-
if (delay.isZero() || delay.isNegative()) {
102-
delay = Duration.ofSeconds(1);
103-
}
104-
// TODO check delay is > 0, schedule 1 second later at least
98+
lock();
99+
try {
100+
if (!t.equals(this.token)) {
101+
this.token = t;
102+
scheduleRenewal(t);
103+
}
104+
} finally {
105+
unlock();
106+
}
107+
}
108+
109+
private void scheduleRenewal(Token t) {
110+
if (this.schedulingRenewal.compareAndSet(false, true)) {
111+
if (this.renewalTask != null) {
112+
this.renewalTask.cancel(false);
113+
}
114+
Duration delay = delayBeforeTokenRenewal(t);
115+
if (delay.isZero() || delay.isNegative()) {
116+
delay = Duration.ofSeconds(1);
117+
}
118+
if (!this.registrations.isEmpty()) {
105119
this.renewalTask =
106120
this.scheduledExecutorService.schedule(
107121
() -> {
@@ -111,16 +125,18 @@ private void token(Token t) {
111125
if (this.token.equals(previousToken)) {
112126
Token newToken = getToken();
113127
token(newToken);
114-
refresh();
128+
refreshRegistrations(newToken);
115129
}
116130
} finally {
117131
unlock();
118132
}
119133
},
120134
delay.toMillis(),
121135
TimeUnit.MILLISECONDS);
122-
this.schedulingRenewal.set(false);
136+
} else {
137+
this.renewalTask = null;
123138
}
139+
this.schedulingRenewal.set(false);
124140
}
125141
}
126142

@@ -139,6 +155,7 @@ private RegistrationImpl(Long id, RefreshCallback refreshCallback) {
139155
@Override
140156
public void connect(ConnectionCallback callback) {
141157
boolean shouldRefresh = false;
158+
Token tokenToUse;
142159
lock();
143160
try {
144161
if (token == null) {
@@ -148,23 +165,42 @@ public void connect(ConnectionCallback callback) {
148165
token(getToken());
149166
}
150167
this.registrationToken = token;
168+
tokenToUse = this.registrationToken;
169+
if (renewalTask == null) {
170+
scheduleRenewal(tokenToUse);
171+
}
151172
} finally {
152173
unlock();
153174
}
154175

155-
callback.username("").password(this.registrationToken.value());
176+
callback.username("").password(tokenToUse.value());
156177
if (shouldRefresh) {
157-
refresh();
178+
refreshRegistrations(tokenToUse);
158179
}
159180
}
160181

161182
@Override
162183
public void unregister() {
163184
if (this.closed.compareAndSet(false, true)) {
164185
registrations.remove(this.id);
186+
ScheduledFuture<?> task = renewalTask;
187+
if (registrations.isEmpty() && task != null) {
188+
lock();
189+
try {
190+
if (renewalTask != null) {
191+
renewalTask.cancel(false);
192+
}
193+
} finally {
194+
unlock();
195+
}
196+
}
165197
}
166198
}
167199

200+
private boolean hasSameToken(Token t) {
201+
return t.equals(this.registrationToken);
202+
}
203+
168204
private boolean isClosed() {
169205
return this.closed.get();
170206
}

src/test/java/com/rabbitmq/client/amqp/impl/Cli.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ void pump() {
508508
String line;
509509
while (true) {
510510
try {
511-
if (!((line = reader.readLine()) != null)) break;
511+
if ((line = reader.readLine()) == null) break;
512512
} catch (IOException e) {
513513
throw new RuntimeException(e);
514514
}

src/test/java/com/rabbitmq/client/amqp/impl/ConsumerOutcomeTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ void discardedMessageShouldBeDeadLeadLetteredWhenConfigured() {
180180
}
181181

182182
@Test
183-
void discardedMessageWithAnnotationsShouldBeDeadLeadLetteredAndContainAnnotationsWhenConfigured() {
183+
void
184+
discardedMessageWithAnnotationsShouldBeDeadLeadLetteredAndContainAnnotationsWhenConfigured() {
184185
declareDeadLetterTopology();
185186
Publisher publisher = this.connection.publisherBuilder().queue(q).build();
186187
this.connection

src/test/java/com/rabbitmq/client/amqp/impl/TestUtils.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
public abstract class TestUtils {
5454

5555
static final Duration DEFAULT_CONDITION_TIMEOUT = Duration.ofSeconds(10);
56+
static final Duration DEFAULT_WAIT_TIME = Duration.ofMillis(100);
5657

5758
private TestUtils() {}
5859

@@ -113,6 +114,7 @@ public static Duration waitAtMostNoException(RunnableWithException condition) {
113114
public static Duration waitAtMostNoException(Duration timeout, RunnableWithException condition) {
114115
return waitAtMost(
115116
timeout,
117+
DEFAULT_WAIT_TIME,
116118
() -> {
117119
try {
118120
condition.run();
@@ -125,15 +127,20 @@ public static Duration waitAtMostNoException(Duration timeout, RunnableWithExcep
125127
}
126128

127129
public static Duration waitAtMost(CallableBooleanSupplier condition) {
128-
return waitAtMost(DEFAULT_CONDITION_TIMEOUT, condition, null);
130+
return waitAtMost(DEFAULT_CONDITION_TIMEOUT, DEFAULT_WAIT_TIME, condition, null);
129131
}
130132

131133
public static Duration waitAtMost(CallableBooleanSupplier condition, Supplier<String> message) {
132-
return waitAtMost(DEFAULT_CONDITION_TIMEOUT, condition, message);
134+
return waitAtMost(DEFAULT_CONDITION_TIMEOUT, DEFAULT_WAIT_TIME, condition, message);
135+
}
136+
137+
public static Duration waitAtMost(
138+
Duration timeout, Duration waitTime, CallableBooleanSupplier condition) {
139+
return waitAtMost(timeout, waitTime, condition, null);
133140
}
134141

135142
public static Duration waitAtMost(Duration timeout, CallableBooleanSupplier condition) {
136-
return waitAtMost(timeout, condition, null);
143+
return waitAtMost(timeout, DEFAULT_WAIT_TIME, condition, null);
137144
}
138145

139146
public static Duration waitAtMost(int timeoutInSeconds, CallableBooleanSupplier condition) {
@@ -142,17 +149,24 @@ public static Duration waitAtMost(int timeoutInSeconds, CallableBooleanSupplier
142149

143150
public static Duration waitAtMost(
144151
int timeoutInSeconds, CallableBooleanSupplier condition, Supplier<String> message) {
145-
return waitAtMost(Duration.ofSeconds(timeoutInSeconds), condition, message);
152+
return waitAtMost(Duration.ofSeconds(timeoutInSeconds), DEFAULT_WAIT_TIME, condition, message);
146153
}
147154

148155
public static Duration waitAtMost(
149156
Duration timeout, CallableBooleanSupplier condition, Supplier<String> message) {
157+
return waitAtMost(timeout, DEFAULT_WAIT_TIME, condition, message);
158+
}
159+
160+
public static Duration waitAtMost(
161+
Duration timeout,
162+
Duration waitTime,
163+
CallableBooleanSupplier condition,
164+
Supplier<String> message) {
150165
long start = System.nanoTime();
151166
try {
152167
if (condition.getAsBoolean()) {
153168
return Duration.ZERO;
154169
}
155-
Duration waitTime = Duration.ofMillis(100);
156170
Duration waitedTime = Duration.ofNanos(System.nanoTime() - start);
157171
Exception exception = null;
158172
while (waitedTime.compareTo(timeout) <= 0) {
@@ -429,7 +443,6 @@ private static class DisabledIfOauth2AuthBackendNotEnabledCondition
429443
static class DisabledIfNotClusterCondition implements ExecutionCondition {
430444

431445
private static final String KEY = "isCluster";
432-
private static final String KEY_NODES = "clusterNodes";
433446

434447
@Override
435448
public ConditionEvaluationResult evaluateExecutionCondition(ExtensionContext context) {

src/test/java/com/rabbitmq/client/amqp/impl/TokenCredentialsTest.java

Lines changed: 93 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,21 @@
1818
package com.rabbitmq.client.amqp.impl;
1919

2020
import static com.rabbitmq.client.amqp.impl.Assertions.assertThat;
21+
import static com.rabbitmq.client.amqp.impl.TestUtils.sync;
22+
import static com.rabbitmq.client.amqp.impl.TestUtils.waitAtMost;
23+
import static java.util.stream.Collectors.toList;
24+
import static java.util.stream.IntStream.range;
25+
import static org.assertj.core.api.Assertions.assertThat;
2126
import static org.mockito.Mockito.when;
2227

2328
import com.rabbitmq.client.amqp.impl.TestUtils.Sync;
2429
import com.rabbitmq.client.amqp.oauth.Token;
2530
import com.rabbitmq.client.amqp.oauth.TokenRequester;
31+
import java.time.Duration;
32+
import java.util.List;
2633
import java.util.concurrent.Executors;
2734
import java.util.concurrent.ScheduledExecutorService;
35+
import java.util.concurrent.atomic.AtomicInteger;
2836
import org.junit.jupiter.api.AfterEach;
2937
import org.junit.jupiter.api.BeforeEach;
3038
import org.junit.jupiter.api.Test;
@@ -50,31 +58,86 @@ void tearDown() throws Exception {
5058
}
5159

5260
@Test
53-
void refresh() {
61+
void refreshShouldStopOnceUnregistered() throws InterruptedException {
62+
Duration tokenExpiry = Duration.ofMillis(50);
63+
AtomicInteger requestCount = new AtomicInteger(0);
5464
when(this.requester.request())
55-
.thenAnswer(ignored -> token("ok", System.currentTimeMillis() + 100));
65+
.thenAnswer(
66+
ignored -> {
67+
requestCount.incrementAndGet();
68+
return token("ok", System.currentTimeMillis() + tokenExpiry.toMillis());
69+
});
5670
TokenCredentials credentials =
5771
new TokenCredentials(this.requester, this.scheduledExecutorService);
58-
Sync refreshSync = TestUtils.sync(3);
72+
int expectedRefreshCount = 3;
73+
AtomicInteger refreshCount = new AtomicInteger();
74+
Sync refreshSync = sync(expectedRefreshCount);
5975
Credentials.Registration registration =
6076
credentials.register(
6177
(u, p) -> {
78+
refreshCount.incrementAndGet();
6279
refreshSync.down();
6380
});
64-
registration.connect(
65-
new Credentials.ConnectionCallback() {
66-
@Override
67-
public Credentials.ConnectionCallback username(String username) {
68-
return this;
69-
}
70-
71-
@Override
72-
public Credentials.ConnectionCallback password(String password) {
73-
return this;
74-
}
75-
});
81+
registration.connect(connectionCallback(() -> {}));
82+
assertThat(requestCount).hasValue(1);
7683
assertThat(refreshSync).completes();
84+
assertThat(requestCount).hasValue(expectedRefreshCount + 1);
7785
registration.unregister();
86+
assertThat(refreshCount).hasValue(expectedRefreshCount);
87+
assertThat(requestCount).hasValue(expectedRefreshCount + 1);
88+
Thread.sleep(tokenExpiry.multipliedBy(2).toMillis());
89+
assertThat(refreshCount).hasValue(expectedRefreshCount);
90+
assertThat(requestCount).hasValue(expectedRefreshCount + 1);
91+
}
92+
93+
@Test
94+
void severalRegistrationsShouldBeRefreshed() throws InterruptedException {
95+
Duration tokenExpiry = Duration.ofMillis(50);
96+
Duration waitTime = tokenExpiry.dividedBy(2);
97+
Duration timeout = tokenExpiry.multipliedBy(10);
98+
when(this.requester.request())
99+
.thenAnswer(ignored -> token("ok", System.currentTimeMillis() + tokenExpiry.toMillis()));
100+
TokenCredentials credentials =
101+
new TokenCredentials(this.requester, this.scheduledExecutorService);
102+
int expectedRefreshCountPerConnection = 3;
103+
int connectionCount = 10;
104+
AtomicInteger totalRefreshCount = new AtomicInteger();
105+
List<Tuples.Pair<Credentials.Registration, Sync>> registrations =
106+
range(0, connectionCount)
107+
.mapToObj(
108+
ignored -> {
109+
Sync sync = sync(expectedRefreshCountPerConnection);
110+
Credentials.Registration r =
111+
credentials.register(
112+
(username, password) -> {
113+
totalRefreshCount.incrementAndGet();
114+
sync.down();
115+
});
116+
return Tuples.pair(r, sync);
117+
})
118+
.collect(toList());
119+
120+
registrations.forEach(r -> r.v1().connect(connectionCallback(() -> {})));
121+
registrations.forEach(r -> assertThat(r.v2()).completes());
122+
// all connections have been refreshed once
123+
int refreshCountSnapshot = totalRefreshCount.get();
124+
assertThat(refreshCountSnapshot).isEqualTo(connectionCount * expectedRefreshCountPerConnection);
125+
126+
// unregister half of the connections
127+
int splitCount = connectionCount / 2;
128+
registrations.subList(0, splitCount).forEach(r -> r.v1().unregister());
129+
// only the remaining connections should get refreshed again
130+
waitAtMost(
131+
timeout, waitTime, () -> totalRefreshCount.get() == refreshCountSnapshot + splitCount);
132+
// waiting another round of refresh
133+
waitAtMost(
134+
timeout, waitTime, () -> totalRefreshCount.get() == refreshCountSnapshot + splitCount * 2);
135+
// unregister all connections
136+
registrations.forEach(r -> r.v1().unregister());
137+
// wait 2 expiry times
138+
Thread.sleep(tokenExpiry.multipliedBy(2).toMillis());
139+
// no new refresh
140+
assertThat(totalRefreshCount).hasValue(refreshCountSnapshot + splitCount * 2);
78141
}
79142

80143
private static Token token(String value, long expirationTime) {
@@ -90,4 +153,19 @@ public long expirationTime() {
90153
}
91154
};
92155
}
156+
157+
private static Credentials.ConnectionCallback connectionCallback(Runnable passwordCallback) {
158+
return new Credentials.ConnectionCallback() {
159+
@Override
160+
public Credentials.ConnectionCallback username(String username) {
161+
return this;
162+
}
163+
164+
@Override
165+
public Credentials.ConnectionCallback password(String password) {
166+
passwordCallback.run();
167+
return this;
168+
}
169+
};
170+
}
93171
}

0 commit comments

Comments
 (0)