Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions core/src/main/java/tech/ydb/core/grpc/GrpcTransportBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
import com.google.common.net.HostAndPort;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.ManagedChannel;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;

import tech.ydb.auth.AuthRpcProvider;
import tech.ydb.auth.NopAuthProvider;
import tech.ydb.core.impl.YdbSchedulerFactory;
import tech.ydb.core.impl.YdbTransportImpl;
import tech.ydb.core.impl.auth.GrpcAuthRpc;
import tech.ydb.core.impl.pool.DefaultChannelFactory;
import tech.ydb.core.impl.pool.ChannelFactoryLoader;
import tech.ydb.core.impl.pool.ManagedChannelFactory;
import tech.ydb.core.utils.Version;

Expand Down Expand Up @@ -69,7 +68,7 @@ public enum InitMode {

private byte[] cert = null;
private boolean useTLS = false;
private ManagedChannelFactory.Builder channelFactoryBuilder = DefaultChannelFactory::build;
private ManagedChannelFactory.Builder channelFactoryBuilder = null;
private Supplier<ScheduledExecutorService> schedulerFactory = YdbSchedulerFactory::createScheduler;
private String localDc;
private BalancingSettings balancingSettings;
Expand Down Expand Up @@ -177,6 +176,10 @@ public boolean useDefaultGrpcResolver() {
}

public ManagedChannelFactory getManagedChannelFactory() {
if (channelFactoryBuilder == null) {
channelFactoryBuilder = ChannelFactoryLoader.load();
}

return channelFactoryBuilder.buildFactory(this);
}

Expand All @@ -193,18 +196,20 @@ public GrpcTransportBuilder withChannelFactoryBuilder(ManagedChannelFactory.Buil
}

/**
* Set a custom initialization of {@link NettyChannelBuilder} <br>
* Set a custom initialization of {@link io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder} <br>
* This method is deprecated. Use
* {@link GrpcTransportBuilder#withChannelFactoryBuilder(tech.ydb.core.impl.pool.ManagedChannelFactory.Builder)}
* instead
*
* @param channelInitializer custom NettyChannelBuilder initializator
* @param ci custom NettyChannelBuilder initializator
* @return this
* @deprecated
*/
@Deprecated
public GrpcTransportBuilder withChannelInitializer(Consumer<NettyChannelBuilder> channelInitializer) {
this.channelFactoryBuilder = gtb -> DefaultChannelFactory.build(gtb, channelInitializer);
public GrpcTransportBuilder withChannelInitializer(
Consumer<io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder> ci
) {
this.channelFactoryBuilder = tech.ydb.core.impl.pool.ShadedNettyChannelFactory.withInterceptor(ci);
return this;
}

Expand Down
44 changes: 24 additions & 20 deletions core/src/main/java/tech/ydb/core/impl/YdbDiscovery.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tech.ydb.core.grpc.GrpcTransport;
import tech.ydb.core.impl.pool.EndpointRecord;
import tech.ydb.core.operation.OperationBinder;
import tech.ydb.core.utils.Async;
import tech.ydb.core.utils.FutureTools;
import tech.ydb.proto.discovery.DiscoveryProtos;
import tech.ydb.proto.discovery.v1.DiscoveryServiceGrpc;

Expand Down Expand Up @@ -140,26 +140,30 @@ private void tick() {

private void runDiscovery() {
lastUpdateTime = handler.instant();
final GrpcTransport transport = handler.createDiscoveryTransport();
try {
logger.debug("execute list endpoints on {} with timeout {}", transport, discoveryTimeout);
DiscoveryProtos.ListEndpointsRequest request = DiscoveryProtos.ListEndpointsRequest.newBuilder()
.setDatabase(discoveryDatabase)
.build();

GrpcRequestSettings grpcSettings = GrpcRequestSettings.newBuilder()
.withDeadline(discoveryTimeout)
.build();

transport.unaryCall(DiscoveryServiceGrpc.getListEndpointsMethod(), grpcSettings, request)
.whenComplete((res, ex) -> transport.close()) // close transport for any result
.thenApply(OperationBinder.bindSync(
DiscoveryProtos.ListEndpointsResponse::getOperation,
DiscoveryProtos.ListEndpointsResult.class
))
.whenComplete(this::handleDiscoveryResult);
final GrpcTransport transport = handler.createDiscoveryTransport();
try {
logger.debug("execute list endpoints on {} with timeout {}", transport, discoveryTimeout);
DiscoveryProtos.ListEndpointsRequest request = DiscoveryProtos.ListEndpointsRequest.newBuilder()
.setDatabase(discoveryDatabase)
.build();

GrpcRequestSettings grpcSettings = GrpcRequestSettings.newBuilder()
.withDeadline(discoveryTimeout)
.build();

transport.unaryCall(DiscoveryServiceGrpc.getListEndpointsMethod(), grpcSettings, request)
.whenComplete((res, ex) -> transport.close()) // close transport for any result
.thenApply(OperationBinder.bindSync(
DiscoveryProtos.ListEndpointsResponse::getOperation,
DiscoveryProtos.ListEndpointsResult.class
))
.whenComplete(this::handleDiscoveryResult);
} catch (Throwable th) {
transport.close();
throw th;
}
} catch (Throwable th) {
transport.close();
handleDiscoveryResult(null, th);
}
}
Expand All @@ -183,7 +187,7 @@ private void handleOk(String selfLocation, List<EndpointRecord> endpoints) {

private void handleDiscoveryResult(Result<DiscoveryProtos.ListEndpointsResult> response, Throwable th) {
if (th != null) {
Throwable cause = Async.unwrapCompletionException(th);
Throwable cause = FutureTools.unwrapCompletionException(th);
logger.warn("couldn't perform discovery with exception", cause);
handleThrowable(cause);
return;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package tech.ydb.core.impl.pool;


import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
*
* @author Aleksandr Gorshenin
*/
public class ChannelFactoryLoader {
private static final Logger logger = LoggerFactory.getLogger(ChannelFactoryLoader.class);

private ChannelFactoryLoader() { }

public static ManagedChannelFactory.Builder load() {
return FactoryLoader.factory;
}

private static class FactoryLoader {
private static final String SHADED_DEPS = "io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder";
private static final String NETTY_DEPS = "io.grpc.netty.NettyChannelBuilder";

private static ManagedChannelFactory.Builder factory;

static {
boolean ok = tryLoad(SHADED_DEPS, ShadedNettyChannelFactory.build())
|| tryLoad(NETTY_DEPS, NettyChannelFactory.build());
if (!ok) {
throw new IllegalStateException("Cannot load any ManagedChannelFactory!! "
+ "Classpath must contain grpc-netty or grpc-netty-shaded");
}
}

private static boolean tryLoad(String name, ManagedChannelFactory.Builder f) {
try {
Class.forName(name);
logger.info("class {} is found, use {}", name, f);
factory = f;
return true;
} catch (ClassNotFoundException ex) {
logger.info("class {} is not found", name);
return false;
}
}
}
}
17 changes: 11 additions & 6 deletions core/src/main/java/tech/ydb/core/impl/pool/GrpcChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@ public class GrpcChannel {
private final ReadyWatcher readyWatcher;

public GrpcChannel(EndpointRecord endpoint, ManagedChannelFactory factory) {
logger.debug("Creating grpc channel with {}", endpoint);
this.endpoint = endpoint;
this.channel = factory.newManagedChannel(endpoint.getHost(), endpoint.getPort());
this.connectTimeoutMs = factory.getConnectTimeoutMs();
this.readyWatcher = new ReadyWatcher();
this.readyWatcher.checkState();
try {
logger.debug("Creating grpc channel with {}", endpoint);
this.endpoint = endpoint;
this.channel = factory.newManagedChannel(endpoint.getHost(), endpoint.getPort());
this.connectTimeoutMs = factory.getConnectTimeoutMs();
this.readyWatcher = new ReadyWatcher();
this.readyWatcher.checkState();
} catch (Throwable th) {
logger.error("cannot create channel", th);
throw new RuntimeException("cannot create channel", th);
}
}

public EndpointRecord getEndpoint() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.io.ByteArrayInputStream;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import javax.net.ssl.SSLException;

Expand Down Expand Up @@ -39,7 +40,7 @@ public class NettyChannelFactory implements ManagedChannelFactory {
private final boolean useDefaultGrpcResolver;
private final Long grpcKeepAliveTimeMillis;

public NettyChannelFactory(GrpcTransportBuilder builder) {
private NettyChannelFactory(GrpcTransportBuilder builder) {
this.database = builder.getDatabase();
this.version = builder.getVersionString();
this.useTLS = builder.getUseTls();
Expand Down Expand Up @@ -120,4 +121,29 @@ private SslContext createSslContext() {
throw new RuntimeException("cannot create ssl context", e);
}
}

public static ManagedChannelFactory.Builder build() {
return new Builder() {
@Override
public ManagedChannelFactory buildFactory(GrpcTransportBuilder builder) {
return new NettyChannelFactory(builder);
}

@Override
public String toString() {
return "NettyChannelFactory";
}
};
}

public static ManagedChannelFactory.Builder withInterceptor(Consumer<NettyChannelBuilder> ci) {
return builder -> new NettyChannelFactory(builder) {
@Override
protected void configure(NettyChannelBuilder channelBuilder) {
if (ci != null) {
ci.accept(channelBuilder);
}
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
* @author Nikolay Perfilov
* @author Aleksandr Gorshenin
*/
public class DefaultChannelFactory implements ManagedChannelFactory {
public class ShadedNettyChannelFactory implements ManagedChannelFactory {
static final int INBOUND_MESSAGE_SIZE = 64 << 20; // 64 MiB
static final String DEFAULT_BALANCER_POLICY = "round_robin";

Expand All @@ -40,7 +40,7 @@ public class DefaultChannelFactory implements ManagedChannelFactory {
private final boolean useDefaultGrpcResolver;
private final Long grpcKeepAliveTimeMillis;

private DefaultChannelFactory(GrpcTransportBuilder builder) {
public ShadedNettyChannelFactory(GrpcTransportBuilder builder) {
this.database = builder.getDatabase();
this.version = builder.getVersionString();
this.useTLS = builder.getUseTls();
Expand Down Expand Up @@ -122,12 +122,22 @@ private SslContext createSslContext() {
}
}

public static ManagedChannelFactory build(GrpcTransportBuilder builder) {
return new DefaultChannelFactory(builder);
public static ManagedChannelFactory.Builder build() {
return new Builder() {
@Override
public ManagedChannelFactory buildFactory(GrpcTransportBuilder builder) {
return new ShadedNettyChannelFactory(builder);
}

@Override
public String toString() {
return "ShadedNettyChannelFactory";
}
};
}

public static ManagedChannelFactory build(GrpcTransportBuilder builder, Consumer<NettyChannelBuilder> ci) {
return new DefaultChannelFactory(builder) {
public static ManagedChannelFactory.Builder withInterceptor(Consumer<NettyChannelBuilder> ci) {
return builder -> new ShadedNettyChannelFactory(builder) {
@Override
protected void configure(NettyChannelBuilder channelBuilder) {
if (ci != null) {
Expand Down
81 changes: 0 additions & 81 deletions core/src/main/java/tech/ydb/core/utils/Async.java

This file was deleted.

Loading