Skip to content

Commit 311a313

Browse files
committed
Refactor OAuth 2 tests
1 parent 35a0fd5 commit 311a313

File tree

5 files changed

+199
-61
lines changed

5 files changed

+199
-61
lines changed

src/main/java/com/rabbitmq/stream/impl/StreamConsumer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class StreamConsumer implements Consumer {
6464
private final boolean sac;
6565
private final OffsetSpecification initialOffsetSpecification;
6666
private final Lock lock = new ReentrantLock();
67-
private volatile boolean consuming = false;
67+
private volatile boolean consuming;
6868

6969
@SuppressFBWarnings("CT_CONSTRUCTOR_THROW")
7070
StreamConsumer(

src/test/java/com/rabbitmq/stream/oauth2/HttpTokenRequesterTest.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919

2020
import com.google.gson.Gson;
2121
import com.google.gson.reflect.TypeToken;
22-
import com.rabbitmq.stream.impl.HttpTestUtils;
23-
import com.rabbitmq.stream.impl.TestUtils;
2422
import com.sun.net.httpserver.Headers;
2523
import com.sun.net.httpserver.HttpServer;
2624
import java.io.IOException;
@@ -50,7 +48,7 @@ public class HttpTokenRequesterTest {
5048

5149
@BeforeEach
5250
void init() throws IOException {
53-
this.port = TestUtils.randomNetworkPort();
51+
this.port = OAuth2TestUtils.randomNetworkPort();
5452
}
5553

5654
@ParameterizedTest
@@ -61,7 +59,7 @@ void requestToken(boolean tls) throws Exception {
6159
Consumer<HttpClient.Builder> clientBuilderConsumer;
6260
if (tls) {
6361
protocol = "https";
64-
keyStore = HttpTestUtils.generateKeyPair();
62+
keyStore = OAuth2TestUtils.generateKeyPair();
6563
SSLContext sslContext = SSLContext.getInstance("TLS");
6664
TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
6765
tmf.init(keyStore);
@@ -83,7 +81,7 @@ void requestToken(boolean tls) throws Exception {
8381

8482
Duration expiresIn = Duration.ofSeconds(60);
8583
server =
86-
HttpTestUtils.startServer(
84+
OAuth2TestUtils.startServer(
8785
port,
8886
contextPath,
8987
keyStore,

src/test/java/com/rabbitmq/stream/oauth2/OAuth2TestUtils.java

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,40 @@
1414
1515
package com.rabbitmq.stream.oauth2;
1616

17+
import static org.junit.jupiter.api.Assertions.fail;
18+
19+
import com.sun.net.httpserver.HttpHandler;
20+
import com.sun.net.httpserver.HttpServer;
21+
import com.sun.net.httpserver.HttpsConfigurator;
22+
import com.sun.net.httpserver.HttpsServer;
23+
import java.io.IOException;
24+
import java.math.BigInteger;
25+
import java.net.InetSocketAddress;
26+
import java.net.ServerSocket;
27+
import java.security.KeyPair;
28+
import java.security.KeyPairGenerator;
29+
import java.security.KeyStore;
30+
import java.security.SecureRandom;
31+
import java.security.cert.X509Certificate;
32+
import java.security.spec.ECGenParameterSpec;
1733
import java.time.Duration;
34+
import java.time.Instant;
35+
import java.time.temporal.ChronoUnit;
36+
import java.util.Date;
37+
import java.util.function.Supplier;
38+
import javax.net.ssl.KeyManagerFactory;
39+
import javax.net.ssl.SSLContext;
40+
import org.bouncycastle.asn1.x500.X500NameBuilder;
41+
import org.bouncycastle.asn1.x500.style.BCStyle;
42+
import org.bouncycastle.cert.X509CertificateHolder;
43+
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
44+
import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder;
45+
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
1846

1947
public final class OAuth2TestUtils {
2048

49+
private static final char[] KEY_STORE_PASSWORD = "password".toCharArray();
50+
2151
private OAuth2TestUtils() {}
2252

2353
public static String sampleJsonToken(String accessToken, Duration expiresIn) {
@@ -32,4 +62,144 @@ public static String sampleJsonToken(String accessToken, Duration expiresIn) {
3262
return json.replace("{accessToken}", accessToken)
3363
.replace("{expiresIn}", expiresIn.toSeconds() + "");
3464
}
65+
66+
public static int randomNetworkPort() throws IOException {
67+
ServerSocket socket = new ServerSocket();
68+
socket.bind(null);
69+
int port = socket.getLocalPort();
70+
socket.close();
71+
return port;
72+
}
73+
74+
public static Duration waitAtMost(
75+
Duration timeout,
76+
Duration waitTime,
77+
CallableBooleanSupplier condition,
78+
Supplier<String> message)
79+
throws Exception {
80+
if (condition.getAsBoolean()) {
81+
return Duration.ZERO;
82+
}
83+
Duration waitedTime = Duration.ZERO;
84+
Exception exception = null;
85+
while (waitedTime.compareTo(timeout) <= 0) {
86+
Thread.sleep(waitTime.toMillis());
87+
waitedTime = waitedTime.plus(waitTime);
88+
try {
89+
if (condition.getAsBoolean()) {
90+
return waitedTime;
91+
}
92+
exception = null;
93+
} catch (Exception e) {
94+
exception = e;
95+
}
96+
}
97+
String msg;
98+
if (message == null) {
99+
msg = "Waited " + timeout.getSeconds() + " second(s), condition never got true";
100+
} else {
101+
msg = "Waited " + timeout.getSeconds() + " second(s), " + message.get();
102+
}
103+
if (exception == null) {
104+
fail(msg);
105+
} else {
106+
fail(msg, exception);
107+
}
108+
return waitedTime;
109+
}
110+
111+
public static Duration waitAtMost(
112+
Duration timeout, Duration waitTime, CallableBooleanSupplier condition) throws Exception {
113+
return waitAtMost(timeout, waitTime, condition, null);
114+
}
115+
116+
public static HttpServer startServer(int port, String path, HttpHandler handler) {
117+
return startServer(port, path, null, handler);
118+
}
119+
120+
public static HttpServer startServer(
121+
int port, String path, KeyStore keyStore, HttpHandler handler) {
122+
HttpServer server;
123+
try {
124+
if (keyStore != null) {
125+
KeyManagerFactory keyManagerFactory =
126+
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
127+
keyManagerFactory.init(keyStore, KEY_STORE_PASSWORD);
128+
SSLContext sslContext = SSLContext.getInstance("TLS");
129+
sslContext.init(keyManagerFactory.getKeyManagers(), null, null);
130+
server = HttpsServer.create(new InetSocketAddress(port), 0);
131+
((HttpsServer) server).setHttpsConfigurator(new HttpsConfigurator(sslContext));
132+
} else {
133+
server = HttpServer.create(new InetSocketAddress(port), 0);
134+
}
135+
server.createContext(path, handler);
136+
server.start();
137+
return server;
138+
} catch (Exception e) {
139+
throw new RuntimeException(e);
140+
}
141+
}
142+
143+
public static KeyStore generateKeyPair() {
144+
try {
145+
KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
146+
keyStore.load(null, KEY_STORE_PASSWORD);
147+
148+
KeyPairGenerator kpg = KeyPairGenerator.getInstance("EC");
149+
ECGenParameterSpec spec = new ECGenParameterSpec("secp521r1");
150+
kpg.initialize(spec);
151+
152+
KeyPair kp = kpg.generateKeyPair();
153+
154+
JcaX509v3CertificateBuilder certificateBuilder =
155+
new JcaX509v3CertificateBuilder(
156+
new X500NameBuilder().addRDN(BCStyle.CN, "localhost").build(),
157+
BigInteger.valueOf(new SecureRandom().nextInt()),
158+
Date.from(Instant.now().minus(10, ChronoUnit.DAYS)),
159+
Date.from(Instant.now().plus(10, ChronoUnit.DAYS)),
160+
new X500NameBuilder().addRDN(BCStyle.CN, "localhost").build(),
161+
kp.getPublic());
162+
163+
X509CertificateHolder certificateHolder =
164+
certificateBuilder.build(
165+
new JcaContentSignerBuilder("SHA256withECDSA").build(kp.getPrivate()));
166+
167+
X509Certificate certificate =
168+
new JcaX509CertificateConverter().getCertificate(certificateHolder);
169+
170+
keyStore.setKeyEntry(
171+
"localhost", kp.getPrivate(), KEY_STORE_PASSWORD, new X509Certificate[] {certificate});
172+
173+
return keyStore;
174+
} catch (Exception e) {
175+
throw new RuntimeException(e);
176+
}
177+
}
178+
179+
public static <A, B> Pair<A, B> pair(A v1, B v2) {
180+
return new Pair<>(v1, v2);
181+
}
182+
183+
public interface CallableBooleanSupplier {
184+
boolean getAsBoolean() throws Exception;
185+
}
186+
187+
public static class Pair<A, B> {
188+
189+
private final A v1;
190+
private final B v2;
191+
192+
private Pair(A v1, B v2) {
193+
this.v1 = v1;
194+
this.v2 = v2;
195+
}
196+
197+
public A v1() {
198+
return this.v1;
199+
}
200+
201+
public B v2() {
202+
return this.v2;
203+
}
204+
}
35205
}

src/test/java/com/rabbitmq/stream/oauth2/TokenCredentialsManagerTest.java

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,23 @@
1414
1515
package com.rabbitmq.stream.oauth2;
1616

17-
import static com.rabbitmq.stream.impl.TestUtils.waitAtMost;
17+
import static com.rabbitmq.stream.oauth2.OAuth2TestUtils.pair;
18+
import static com.rabbitmq.stream.oauth2.OAuth2TestUtils.waitAtMost;
1819
import static com.rabbitmq.stream.oauth2.TokenCredentialsManager.DEFAULT_REFRESH_DELAY_STRATEGY;
19-
import static com.rabbitmq.stream.oauth2.Tuples.pair;
2020
import static java.time.Duration.ofMillis;
2121
import static java.time.Duration.ofSeconds;
2222
import static java.util.stream.Collectors.toList;
2323
import static java.util.stream.IntStream.range;
2424
import static org.assertj.core.api.Assertions.assertThat;
2525
import static org.mockito.Mockito.when;
2626

27-
import com.rabbitmq.stream.impl.Assertions;
28-
import com.rabbitmq.stream.impl.TestUtils;
29-
import com.rabbitmq.stream.impl.TestUtils.Sync;
3027
import java.time.Duration;
3128
import java.time.Instant;
3229
import java.util.List;
30+
import java.util.concurrent.CountDownLatch;
3331
import java.util.concurrent.Executors;
3432
import java.util.concurrent.ScheduledExecutorService;
33+
import java.util.concurrent.TimeUnit;
3534
import java.util.concurrent.atomic.AtomicInteger;
3635
import java.util.function.Function;
3736
import org.junit.jupiter.api.AfterEach;
@@ -73,17 +72,22 @@ void refreshShouldStopOnceUnregistered() throws InterruptedException {
7372
this.requester, this.scheduledExecutorService, DEFAULT_REFRESH_DELAY_STRATEGY);
7473
int expectedRefreshCount = 3;
7574
AtomicInteger refreshCount = new AtomicInteger();
76-
Sync refreshSync = TestUtils.sync(expectedRefreshCount);
75+
CountDownLatch refreshSync = new CountDownLatch(expectedRefreshCount);
7776
CredentialsManager.Registration registration =
7877
credentials.register(
7978
"",
8079
(u, p) -> {
8180
refreshCount.incrementAndGet();
82-
refreshSync.down();
81+
refreshSync.countDown();
8382
});
8483
registration.connect(connectionCallback(() -> {}));
8584
assertThat(requestCount).hasValue(1);
86-
Assertions.assertThat(refreshSync).completes();
85+
try {
86+
assertThat(refreshSync.await(ofSeconds(10).toMillis(), TimeUnit.MILLISECONDS)).isTrue();
87+
} catch (InterruptedException e) {
88+
Thread.currentThread().interrupt();
89+
throw new RuntimeException(e);
90+
}
8791
assertThat(requestCount).hasValue(expectedRefreshCount + 1);
8892
registration.close();
8993
assertThat(refreshCount).hasValue(expectedRefreshCount);
@@ -106,24 +110,33 @@ void severalRegistrationsShouldBeRefreshed() throws Exception {
106110
int expectedRefreshCountPerConnection = 3;
107111
int connectionCount = 10;
108112
AtomicInteger totalRefreshCount = new AtomicInteger();
109-
List<Tuples.Pair<CredentialsManager.Registration, Sync>> registrations =
113+
List<OAuth2TestUtils.Pair<CredentialsManager.Registration, CountDownLatch>> registrations =
110114
range(0, connectionCount)
111115
.mapToObj(
112116
ignored -> {
113-
Sync sync = TestUtils.sync(expectedRefreshCountPerConnection);
117+
CountDownLatch sync = new CountDownLatch(expectedRefreshCountPerConnection);
114118
CredentialsManager.Registration r =
115119
credentials.register(
116120
"",
117121
(username, password) -> {
118122
totalRefreshCount.incrementAndGet();
119-
sync.down();
123+
sync.countDown();
120124
});
121125
return pair(r, sync);
122126
})
123127
.collect(toList());
124128

125129
registrations.forEach(r -> r.v1().connect(connectionCallback(() -> {})));
126-
registrations.forEach(r -> Assertions.assertThat(r.v2()).completes());
130+
for (OAuth2TestUtils.Pair<CredentialsManager.Registration, CountDownLatch> registrationPair :
131+
registrations) {
132+
try {
133+
assertThat(registrationPair.v2().await(ofSeconds(10).toMillis(), TimeUnit.MILLISECONDS))
134+
.isTrue();
135+
} catch (InterruptedException e) {
136+
Thread.currentThread().interrupt();
137+
throw new RuntimeException(e);
138+
}
139+
}
127140
// all connections have been refreshed once
128141
int refreshCountSnapshot = totalRefreshCount.get();
129142
assertThat(refreshCountSnapshot).isEqualTo(connectionCount * expectedRefreshCountPerConnection);

src/test/java/com/rabbitmq/stream/oauth2/Tuples.java

Lines changed: 0 additions & 43 deletions
This file was deleted.

0 commit comments

Comments
 (0)