Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Ydb.Sdk/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
- Disable Discovery mode: skip discovery step and client balancing and use connection to start endpoint ([#420](https://github.com/ydb-platform/ydb-dotnet-sdk/issues/420)).

## v0.17.0

- Shutdown channels which are removed from the EndpointPool after discovery calls.
Expand Down
66 changes: 42 additions & 24 deletions src/Ydb.Sdk/src/Ado/YdbConnectionStringBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Ydb.Sdk.Auth;
using Ydb.Sdk.Transport;

namespace Ydb.Sdk.Ado;

Expand Down Expand Up @@ -32,6 +34,7 @@ private void InitDefaultValues()
_enableMultipleHttp2Connections = false;
_maxSendMessageSize = GrpcDefaultSettings.MaxSendMessageSize;
_maxReceiveMessageSize = GrpcDefaultSettings.MaxReceiveMessageSize;
_disableDiscovery = false;
}

public string Host
Expand Down Expand Up @@ -213,6 +216,18 @@ public int MaxReceiveMessageSize

private int _maxReceiveMessageSize;

public bool DisableDiscovery
{
get => _disableDiscovery;
set
{
_disableDiscovery = value;
SaveValue(nameof(DisableDiscovery), value);
}
}

private bool _disableDiscovery;

public ILoggerFactory? LoggerFactory { get; init; }

public ICredentialsProvider? CredentialsProvider { get; init; }
Expand Down Expand Up @@ -257,33 +272,34 @@ public override object this[string keyword]

private string Endpoint => $"{(UseTls ? "grpcs" : "grpc")}://{Host}:{Port}";

internal Task<Driver> BuildDriver()
internal async Task<IDriver> BuildDriver()
{
var cert = RootCertificate != null ? X509Certificate.CreateFromCertFile(RootCertificate) : null;
var driverConfig = new DriverConfig(
endpoint: Endpoint,
database: Database,
credentials: CredentialsProvider,
customServerCertificate: cert,
customServerCertificates: ServerCertificates
)
{
KeepAlivePingDelay = KeepAlivePingDelay == 0
? Timeout.InfiniteTimeSpan
: TimeSpan.FromSeconds(KeepAlivePingDelay),
KeepAlivePingTimeout = KeepAlivePingTimeout == 0
? Timeout.InfiniteTimeSpan
: TimeSpan.FromSeconds(KeepAlivePingTimeout),
User = User,
Password = Password,
EnableMultipleHttp2Connections = EnableMultipleHttp2Connections,
MaxSendMessageSize = MaxSendMessageSize,
MaxReceiveMessageSize = MaxReceiveMessageSize
};
var loggerFactory = LoggerFactory ?? NullLoggerFactory.Instance;

return Driver.CreateInitialized(
new DriverConfig(
endpoint: Endpoint,
database: Database,
credentials: CredentialsProvider,
customServerCertificate: cert,
customServerCertificates: ServerCertificates
)
{
KeepAlivePingDelay = KeepAlivePingDelay == 0
? Timeout.InfiniteTimeSpan
: TimeSpan.FromSeconds(KeepAlivePingDelay),
KeepAlivePingTimeout = KeepAlivePingTimeout == 0
? Timeout.InfiniteTimeSpan
: TimeSpan.FromSeconds(KeepAlivePingTimeout),
User = User,
Password = Password,
EnableMultipleHttp2Connections = EnableMultipleHttp2Connections,
MaxSendMessageSize = MaxSendMessageSize,
MaxReceiveMessageSize = MaxReceiveMessageSize
},
LoggerFactory
);
return DisableDiscovery
? new DirectGrpcChannelDriver(driverConfig, loggerFactory)
: await Driver.CreateInitialized(driverConfig, loggerFactory);
}

public override void Clear()
Expand Down Expand Up @@ -369,6 +385,8 @@ static YdbConnectionOption()
AddOption(new YdbConnectionOption<int>(IntExtractor, (builder, maxReceiveMessageSize) =>
builder.MaxReceiveMessageSize = maxReceiveMessageSize),
"MaxReceiveMessageSize", "Max Receive Message Size");
AddOption(new YdbConnectionOption<bool>(BoolExtractor, (builder, disableDiscovery) =>
builder.DisableDiscovery = disableDiscovery), "DisableDiscovery", "Disable Discovery");
}

private static void AddOption(YdbConnectionOption option, params string[] keys)
Expand Down
36 changes: 5 additions & 31 deletions src/Ydb.Sdk/src/Driver.cs
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
using System.Collections.Immutable;
using Grpc.Core;
using Grpc.Net.Client;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Ydb.Discovery;
using Ydb.Discovery.V1;
using Ydb.Sdk.Auth;
using Ydb.Sdk.Pool;
using Ydb.Sdk.Services.Auth;

namespace Ydb.Sdk;

public sealed class Driver : BaseDriver
{
private const int AttemptDiscovery = 10;

private readonly GrpcChannelFactory _grpcChannelFactory;
private readonly EndpointPool _endpointPool;
private readonly ChannelPool<GrpcChannel> _channelPool;

internal string Database => Config.Database;

Expand All @@ -26,19 +21,7 @@ public Driver(DriverConfig config, ILoggerFactory? loggerFactory = null)
(loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<Driver>()
)
{
_grpcChannelFactory = new GrpcChannelFactory(LoggerFactory, config);
_endpointPool = new EndpointPool(LoggerFactory.CreateLogger<EndpointPool>());
_channelPool = new ChannelPool<GrpcChannel>(
LoggerFactory.CreateLogger<ChannelPool<GrpcChannel>>(),
_grpcChannelFactory
);

CredentialsProvider = Config.User != null
? new CachedCredentialsProvider(
new StaticCredentialsAuthClient(config, _grpcChannelFactory, LoggerFactory),
LoggerFactory
)
: Config.Credentials;
_endpointPool = new EndpointPool(LoggerFactory);
}

public static async Task<Driver> CreateInitialized(DriverConfig config, ILoggerFactory? loggerFactory = null)
Expand All @@ -48,8 +31,6 @@ public static async Task<Driver> CreateInitialized(DriverConfig config, ILoggerF
return driver;
}

protected override ValueTask InternalDispose() => _channelPool.DisposeAsync();

public async Task Initialize()
{
Logger.LogInformation("Started initial endpoint discovery");
Expand Down Expand Up @@ -78,18 +59,13 @@ public async Task Initialize()
}
}

await Task.Delay(TimeSpan.FromMilliseconds(i * 200)); // await 0 ms, 200 ms, 400ms, .. 1.8 sec
await Task.Delay(TimeSpan.FromMilliseconds(i * 200)); // await 0 ms, 200 ms, 400ms, ... 1.8 sec
}

throw new InitializationFailureException("Error during initial endpoint discovery");
}

protected override (string, GrpcChannel) GetChannel(long nodeId)
{
var endpoint = _endpointPool.GetEndpoint(nodeId);

return (endpoint, _channelPool.GetChannel(endpoint));
}
protected override string GetEndpoint(long nodeId) => _endpointPool.GetEndpoint(nodeId);

protected override void OnRpcError(string endpoint, RpcException e)
{
Expand All @@ -114,11 +90,9 @@ Grpc.Core.StatusCode.DeadlineExceeded or
_ = Task.Run(DiscoverEndpoints);
}

protected override ICredentialsProvider? CredentialsProvider { get; }

private async Task<Status> DiscoverEndpoints()
{
using var channel = _grpcChannelFactory.CreateChannel(Config.Endpoint);
using var channel = GrpcChannelFactory.CreateChannel(Config.Endpoint);

var client = new DiscoveryService.DiscoveryServiceClient(channel);

Expand Down Expand Up @@ -167,7 +141,7 @@ private async Task<Status> DiscoverEndpoints()
resultProto.Endpoints.Count, resultProto.SelfLocation, Config.SdkVersion
);

await _channelPool.RemoveChannels(
await ChannelPool.RemoveChannels(
_endpointPool.Reset(resultProto.Endpoints
.Select(endpointSettings => new EndpointSettings(
(int)endpointSettings.NodeId,
Expand Down
6 changes: 6 additions & 0 deletions src/Ydb.Sdk/src/DriverConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ public DriverConfig(
SdkVersion = $"ydb-dotnet-sdk/{versionStr}";
}

internal Grpc.Core.Metadata GetCallMetadata => new()
{
{ Metadata.RpcDatabaseHeader, Database },
{ Metadata.RpcSdkInfoHeader, SdkVersion }
};

private static string FormatEndpoint(string endpoint)
{
endpoint = endpoint.ToLower().Trim();
Expand Down
62 changes: 40 additions & 22 deletions src/Ydb.Sdk/src/IDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
using Grpc.Net.Client;
using Microsoft.Extensions.Logging;
using Ydb.Sdk.Auth;
using Ydb.Sdk.Pool;
using Ydb.Sdk.Services.Auth;

namespace Ydb.Sdk;

Expand Down Expand Up @@ -45,16 +47,35 @@ public interface IBidirectionalStream<in TRequest, out TResponse> : IDisposable

public abstract class BaseDriver : IDriver
{
private readonly ICredentialsProvider? _credentialsProvider;

protected readonly DriverConfig Config;
protected readonly ILogger Logger;

internal readonly GrpcChannelFactory GrpcChannelFactory;
internal readonly ChannelPool<GrpcChannel> ChannelPool;

protected int Disposed;

protected BaseDriver(DriverConfig config, ILoggerFactory loggerFactory, ILogger logger)
internal BaseDriver(
DriverConfig config,
ILoggerFactory loggerFactory,
ILogger logger
)
{
Config = config;
Logger = logger;
LoggerFactory = loggerFactory;

GrpcChannelFactory = new GrpcChannelFactory(LoggerFactory, Config);
ChannelPool = new ChannelPool<GrpcChannel>(LoggerFactory, GrpcChannelFactory);

_credentialsProvider = Config.User != null
? new CachedCredentialsProvider(
new StaticCredentialsAuthClient(Config, GrpcChannelFactory, LoggerFactory),
LoggerFactory
)
: Config.Credentials;
}

public async Task<TResponse> UnaryCall<TRequest, TResponse>(
Expand All @@ -64,7 +85,9 @@ public async Task<TResponse> UnaryCall<TRequest, TResponse>(
where TRequest : class
where TResponse : class
{
var (endpoint, channel) = GetChannel(settings.NodeId);
var endpoint = GetEndpoint(settings.NodeId);
var channel = ChannelPool.GetChannel(endpoint);

var callInvoker = channel.CreateCallInvoker();

Logger.LogTrace("Unary call, method: {MethodName}, endpoint: {Endpoint}", method.Name, endpoint);
Expand Down Expand Up @@ -97,7 +120,9 @@ public async ValueTask<ServerStream<TResponse>> ServerStreamCall<TRequest, TResp
where TRequest : class
where TResponse : class
{
var (endpoint, channel) = GetChannel(settings.NodeId);
var endpoint = GetEndpoint(settings.NodeId);
var channel = ChannelPool.GetChannel(endpoint);

var callInvoker = channel.CreateCallInvoker();

var call = callInvoker.AsyncServerStreamingCall(
Expand All @@ -115,7 +140,9 @@ public async ValueTask<IBidirectionalStream<TRequest, TResponse>> BidirectionalS
where TRequest : class
where TResponse : class
{
var (endpoint, channel) = GetChannel(settings.NodeId);
var endpoint = GetEndpoint(settings.NodeId);
var channel = ChannelPool.GetChannel(endpoint);

var callInvoker = channel.CreateCallInvoker();

var call = callInvoker.AsyncDuplexStreamingCall(
Expand All @@ -126,36 +153,29 @@ public async ValueTask<IBidirectionalStream<TRequest, TResponse>> BidirectionalS
return new BidirectionalStream<TRequest, TResponse>(
call,
e => { OnRpcError(endpoint, e); },
CredentialsProvider
_credentialsProvider
);
}

protected abstract (string, GrpcChannel) GetChannel(long nodeId);
protected abstract string GetEndpoint(long nodeId);

protected abstract void OnRpcError(string endpoint, RpcException e);

protected async ValueTask<CallOptions> GetCallOptions(GrpcRequestSettings settings)
{
var meta = new Grpc.Core.Metadata
{
{ Metadata.RpcDatabaseHeader, Config.Database },
{ Metadata.RpcSdkInfoHeader, Config.SdkVersion }
};
var meta = Config.GetCallMetadata;

if (CredentialsProvider != null)
if (_credentialsProvider != null)
{
meta.Add(Metadata.RpcAuthHeader, await CredentialsProvider.GetAuthInfoAsync());
meta.Add(Metadata.RpcAuthHeader, await _credentialsProvider.GetAuthInfoAsync());
}

if (settings.TraceId.Length > 0)
{
meta.Add(Metadata.RpcTraceIdHeader, settings.TraceId);
}

var options = new CallOptions(
headers: meta,
cancellationToken: settings.CancellationToken
);
var options = new CallOptions(headers: meta, cancellationToken: settings.CancellationToken);

if (settings.TransportTimeout != TimeSpan.Zero)
{
Expand All @@ -165,8 +185,6 @@ protected async ValueTask<CallOptions> GetCallOptions(GrpcRequestSettings settin
return options;
}

protected abstract ICredentialsProvider? CredentialsProvider { get; }

public ILoggerFactory LoggerFactory { get; }

public void Dispose() => DisposeAsync().AsTask().GetAwaiter().GetResult();
Expand All @@ -175,11 +193,11 @@ public async ValueTask DisposeAsync()
{
if (Interlocked.CompareExchange(ref Disposed, 1, 0) == 0)
{
await InternalDispose();
await ChannelPool.DisposeAsync();

GC.SuppressFinalize(this);
}
}

protected abstract ValueTask InternalDispose();
}

public sealed class ServerStream<TResponse> : IAsyncEnumerator<TResponse>, IAsyncEnumerable<TResponse>
Expand Down
4 changes: 2 additions & 2 deletions src/Ydb.Sdk/src/Pool/ChannelPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ internal class ChannelPool<T> : IAsyncDisposable where T : ChannelBase, IDisposa
private readonly ILogger<ChannelPool<T>> _logger;
private readonly IChannelFactory<T> _channelFactory;

public ChannelPool(ILogger<ChannelPool<T>> logger, IChannelFactory<T> channelFactory)
public ChannelPool(ILoggerFactory loggerFactory, IChannelFactory<T> channelFactory)
{
_logger = logger;
_logger = loggerFactory.CreateLogger<ChannelPool<T>>();
_channelFactory = channelFactory;
}

Expand Down
4 changes: 2 additions & 2 deletions src/Ydb.Sdk/src/Pool/EndpointPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ internal class EndpointPool
private Dictionary<long, string> _nodeIdToEndpoint = new();
private int _preferredEndpointCount;

internal EndpointPool(ILogger<EndpointPool> logger, IRandom? random = null)
internal EndpointPool(ILoggerFactory loggerFactory, IRandom? random = null)
{
_logger = logger;
_logger = loggerFactory.CreateLogger<EndpointPool>();
_random = random ?? ThreadLocalRandom.Instance;
}

Expand Down
Loading
Loading