Skip to content

Commit a133f3d

Browse files
authored
Merge pull request #322 from dcherednik/ip_discovery
Support for using ip address in discovery response.
2 parents 14c04e1 + ce9981a commit a133f3d

14 files changed

+71
-36
lines changed

core/src/main/java/tech/ydb/core/impl/YdbDiscovery.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import tech.ydb.core.operation.OperationBinder;
2424
import tech.ydb.core.utils.FutureTools;
2525
import tech.ydb.proto.discovery.DiscoveryProtos;
26+
import tech.ydb.proto.discovery.DiscoveryProtos.EndpointInfo;
2627
import tech.ydb.proto.discovery.v1.DiscoveryServiceGrpc;
2728

2829
/**
@@ -185,6 +186,21 @@ private void handleOk(String selfLocation, List<EndpointRecord> endpoints) {
185186
}
186187
}
187188

189+
private static String createAddress(EndpointInfo e) {
190+
String addr;
191+
if (e.getIpV6Count() > 0 && e.getIpV6(0) != null && !e.getIpV6(0).isEmpty()) {
192+
addr = e.getIpV6(0);
193+
} else if (e.getIpV4Count() > 0 && e.getIpV4(0) != null && !e.getIpV4(0).isEmpty()) {
194+
addr = e.getIpV4(0);
195+
} else {
196+
addr = e.getAddress();
197+
}
198+
199+
logger.debug("address {} will be used to connect to node {}", addr, e.getAddress());
200+
201+
return addr;
202+
}
203+
188204
private void handleDiscoveryResult(Result<DiscoveryProtos.ListEndpointsResult> response, Throwable th) {
189205
if (th != null) {
190206
Throwable cause = FutureTools.unwrapCompletionException(th);
@@ -202,7 +218,8 @@ private void handleDiscoveryResult(Result<DiscoveryProtos.ListEndpointsResult> r
202218
}
203219

204220
List<EndpointRecord> records = result.getEndpointsList().stream()
205-
.map(e -> new EndpointRecord(e.getAddress(), e.getPort(), e.getNodeId(), e.getLocation()))
221+
.map(e -> new EndpointRecord(createAddress(e), e.getPort(), e.getNodeId(), e.getLocation(),
222+
e.getSslTargetNameOverride()))
206223
.collect(Collectors.toList());
207224

208225
logger.debug("successfully received ListEndpoints result with {} endpoints", records.size());

core/src/main/java/tech/ydb/core/impl/pool/EndpointRecord.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,35 @@ public class EndpointRecord {
99
private final String host;
1010
private final String hostAndPort;
1111
private final String locationDC;
12+
private final String authority;
1213
private final int port;
1314
private final int nodeId;
1415

15-
public EndpointRecord(String host, int port, int nodeId, String locationDC) {
16+
public EndpointRecord(String host, int port, int nodeId, String locationDC, String authority) {
1617
this.host = Objects.requireNonNull(host);
1718
this.port = port;
1819
this.hostAndPort = host + ":" + port;
1920
this.nodeId = nodeId;
2021
this.locationDC = locationDC;
22+
if (authority != null && !authority.isEmpty()) {
23+
this.authority = authority;
24+
} else {
25+
this.authority = null;
26+
}
2127
}
2228

2329
public EndpointRecord(String host, int port) {
24-
this(host, port, 0, null);
30+
this(host, port, 0, null, null);
2531
}
2632

2733
public String getHost() {
2834
return host;
2935
}
3036

37+
public String getAuthority() {
38+
return authority;
39+
}
40+
3141
public int getPort() {
3242
return port;
3343
}
@@ -46,6 +56,7 @@ public String getLocation() {
4656

4757
@Override
4858
public String toString() {
49-
return "Endpoint{host=" + host + ", port=" + port + ", node=" + nodeId + ", location=" + locationDC + "}";
59+
return "Endpoint{host=" + host + ", port=" + port + ", node=" + nodeId +
60+
", location=" + locationDC + ", overrideAuthority=" + authority + "}";
5061
}
5162
}

core/src/main/java/tech/ydb/core/impl/pool/GrpcChannel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ public GrpcChannel(EndpointRecord endpoint, ManagedChannelFactory factory) {
2828
try {
2929
logger.debug("Creating grpc channel with {}", endpoint);
3030
this.endpoint = endpoint;
31-
this.channel = factory.newManagedChannel(endpoint.getHost(), endpoint.getPort());
31+
this.channel = factory.newManagedChannel(endpoint.getHost(), endpoint.getPort(),
32+
endpoint.getAuthority());
3233
this.connectTimeoutMs = factory.getConnectTimeoutMs();
3334
this.readyWatcher = new ReadyWatcher();
3435
this.readyWatcher.checkState();

core/src/main/java/tech/ydb/core/impl/pool/ManagedChannelFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ interface Builder {
1313
ManagedChannelFactory buildFactory(GrpcTransportBuilder builder);
1414
}
1515

16-
ManagedChannel newManagedChannel(String host, int port);
16+
ManagedChannel newManagedChannel(String host, int port, String authority);
1717

1818
long getConnectTimeoutMs();
1919
}

core/src/main/java/tech/ydb/core/impl/pool/NettyChannelFactory.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,17 @@ public long getConnectTimeoutMs() {
5858

5959
@SuppressWarnings("deprecation")
6060
@Override
61-
public ManagedChannel newManagedChannel(String host, int port) {
61+
public ManagedChannel newManagedChannel(String host, int port, String sslHostOverride) {
6262
NettyChannelBuilder channelBuilder = NettyChannelBuilder
6363
.forAddress(host, port);
6464

6565
if (useTLS) {
6666
channelBuilder
6767
.negotiationType(NegotiationType.TLS)
6868
.sslContext(createSslContext());
69+
if (sslHostOverride != null) {
70+
channelBuilder.overrideAuthority(sslHostOverride);
71+
}
6972
} else {
7073
channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
7174
}

core/src/main/java/tech/ydb/core/impl/pool/ShadedNettyChannelFactory.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,17 @@ public long getConnectTimeoutMs() {
5858

5959
@SuppressWarnings("deprecation")
6060
@Override
61-
public ManagedChannel newManagedChannel(String host, int port) {
61+
public ManagedChannel newManagedChannel(String host, int port, String sslHostOverride) {
6262
NettyChannelBuilder channelBuilder = NettyChannelBuilder
6363
.forAddress(host, port);
6464

6565
if (useTLS) {
6666
channelBuilder
6767
.negotiationType(NegotiationType.TLS)
6868
.sslContext(createSslContext());
69+
if (sslHostOverride != null) {
70+
channelBuilder.overrideAuthority(sslHostOverride);
71+
}
6972
} else {
7073
channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
7174
}

core/src/test/java/tech/ydb/core/impl/YdbDiscoveryTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public void setUp() throws InterruptedException {
4040
Mockito.when(channel.shutdownNow()).thenReturn(channel);
4141
Mockito.when(channel.awaitTermination(Mockito.anyLong(), Mockito.any())).thenReturn(true);
4242

43-
Mockito.when(channelFactory.newManagedChannel(Mockito.any(), Mockito.anyInt())).thenReturn(channel);
43+
Mockito.when(channelFactory.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull())).thenReturn(channel);
4444
}
4545

4646
private <T extends Throwable> T checkFutureException(CompletableFuture<Boolean> f, String message, Class<T> clazz) {

core/src/test/java/tech/ydb/core/impl/YdbTransportImplTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ public void setUp() throws InterruptedException {
5151
Mockito.when(transportChannel.shutdownNow()).thenReturn(transportChannel);
5252
Mockito.when(transportChannel.awaitTermination(Mockito.anyLong(), Mockito.any())).thenReturn(true);
5353

54-
Mockito.when(channelFactory.newManagedChannel(Mockito.eq("mocked"), Mockito.eq(2136)))
54+
Mockito.when(channelFactory.newManagedChannel(Mockito.eq("mocked"), Mockito.eq(2136), Mockito.isNull()))
5555
.thenReturn(discoveryChannel);
56-
Mockito.when(channelFactory.newManagedChannel(Mockito.eq("node"), Mockito.eq(2136)))
56+
Mockito.when(channelFactory.newManagedChannel(Mockito.eq("node"), Mockito.eq(2136), Mockito.isNull()))
5757
.thenReturn(transportChannel);
5858
}
5959

core/src/test/java/tech/ydb/core/impl/pool/DefaultChannelFactoryTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public void defaultParams() {
7676
channelStaticMock.verify(FOR_ADDRESS, times(0));
7777

7878
Assert.assertEquals(30_000l, factory.getConnectTimeoutMs());
79-
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
79+
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));
8080

8181
channelStaticMock.verify(FOR_ADDRESS, times(1));
8282

@@ -100,7 +100,7 @@ public void defaultSslFactory() {
100100
channelStaticMock.verify(FOR_ADDRESS, times(0));
101101

102102
Assert.assertEquals(60000l, factory.getConnectTimeoutMs());
103-
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
103+
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));
104104

105105
channelStaticMock.verify(FOR_ADDRESS, times(1));
106106

@@ -124,7 +124,7 @@ public void customChannelInitializer() {
124124

125125
channelStaticMock.verify(FOR_ADDRESS, times(0));
126126

127-
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
127+
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));
128128

129129
channelStaticMock.verify(FOR_ADDRESS, times(1));
130130

@@ -150,7 +150,7 @@ public void customSslFactory() throws CertificateException, IOException {
150150
ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder);
151151

152152
Assert.assertEquals(4000l, factory.getConnectTimeoutMs());
153-
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
153+
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));
154154

155155
} finally {
156156
selfSignedCert.delete();
@@ -176,7 +176,7 @@ public void invalidSslCert() {
176176
ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder);
177177

178178
RuntimeException ex = Assert.assertThrows(RuntimeException.class,
179-
() -> factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
179+
() -> factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));
180180

181181
Assert.assertEquals("cannot create ssl context", ex.getMessage());
182182
Assert.assertNotNull(ex.getCause());

core/src/test/java/tech/ydb/core/impl/pool/EndpointPoolTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ public void nodePessimizationTest() {
208208
check(pool.getEndpoint(2)).hostname("n2.ydb.tech").nodeID(2).port(12342);
209209

210210
// Pessimize unknown nodes - nothing is changed
211-
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12341, 2, null));
212-
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12342, 2, null));
211+
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12341, 2, null, null));
212+
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12342, 2, null, null));
213213
pool.pessimizeEndpoint(null);
214214
check(pool).records(5).knownNodes(5).needToReDiscovery(false).bestEndpointsCount(4);
215215

@@ -553,6 +553,6 @@ private static List<EndpointRecord> list(EndpointRecord... records) {
553553
}
554554

555555
private static EndpointRecord endpoint(int nodeID, String hostname, int port, String location) {
556-
return new EndpointRecord(hostname, port, nodeID, location);
556+
return new EndpointRecord(hostname, port, nodeID, location, null);
557557
}
558558
}

0 commit comments

Comments
 (0)