Skip to content

Commit 06b0e42

Browse files
fix bug with credentials
1 parent 097b5b8 commit 06b0e42

File tree

9 files changed

+66
-106
lines changed

9 files changed

+66
-106
lines changed

src/Ydb.Sdk/src/Driver.cs

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
11
using System.Collections.Immutable;
22
using Grpc.Core;
3-
using Grpc.Net.Client;
43
using Microsoft.Extensions.Logging;
54
using Microsoft.Extensions.Logging.Abstractions;
65
using Ydb.Discovery;
76
using Ydb.Discovery.V1;
8-
using Ydb.Sdk.Auth;
97
using Ydb.Sdk.Pool;
10-
using Ydb.Sdk.Services.Auth;
118

129
namespace Ydb.Sdk;
1310

1411
public sealed class Driver : BaseDriver
1512
{
1613
private const int AttemptDiscovery = 10;
1714

18-
private readonly GrpcChannelFactory _grpcChannelFactory;
1915
private readonly EndpointPool _endpointPool;
20-
private readonly ChannelPool<GrpcChannel> _channelPool;
2116

2217
internal string Database => Config.Database;
2318

@@ -26,19 +21,7 @@ public Driver(DriverConfig config, ILoggerFactory? loggerFactory = null)
2621
(loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<Driver>()
2722
)
2823
{
29-
_grpcChannelFactory = new GrpcChannelFactory(LoggerFactory, config);
30-
_endpointPool = new EndpointPool(LoggerFactory.CreateLogger<EndpointPool>());
31-
_channelPool = new ChannelPool<GrpcChannel>(
32-
LoggerFactory.CreateLogger<ChannelPool<GrpcChannel>>(),
33-
_grpcChannelFactory
34-
);
35-
36-
CredentialsProvider = Config.User != null
37-
? new CachedCredentialsProvider(
38-
new StaticCredentialsAuthClient(config, _grpcChannelFactory, LoggerFactory),
39-
LoggerFactory
40-
)
41-
: Config.Credentials;
24+
_endpointPool = new EndpointPool(LoggerFactory);
4225
}
4326

4427
public static async Task<Driver> CreateInitialized(DriverConfig config, ILoggerFactory? loggerFactory = null)
@@ -48,8 +31,6 @@ public static async Task<Driver> CreateInitialized(DriverConfig config, ILoggerF
4831
return driver;
4932
}
5033

51-
protected override ValueTask InternalDispose() => _channelPool.DisposeAsync();
52-
5334
public async Task Initialize()
5435
{
5536
Logger.LogInformation("Started initial endpoint discovery");
@@ -78,18 +59,13 @@ public async Task Initialize()
7859
}
7960
}
8061

81-
await Task.Delay(TimeSpan.FromMilliseconds(i * 200)); // await 0 ms, 200 ms, 400ms, .. 1.8 sec
62+
await Task.Delay(TimeSpan.FromMilliseconds(i * 200)); // await 0 ms, 200 ms, 400ms, ... 1.8 sec
8263
}
8364

8465
throw new InitializationFailureException("Error during initial endpoint discovery");
8566
}
8667

87-
protected override (string, GrpcChannel) GetChannel(long nodeId)
88-
{
89-
var endpoint = _endpointPool.GetEndpoint(nodeId);
90-
91-
return (endpoint, _channelPool.GetChannel(endpoint));
92-
}
68+
protected override string GetEndpoint(long nodeId) => _endpointPool.GetEndpoint(nodeId);
9369

9470
protected override void OnRpcError(string endpoint, RpcException e)
9571
{
@@ -114,11 +90,9 @@ Grpc.Core.StatusCode.DeadlineExceeded or
11490
_ = Task.Run(DiscoverEndpoints);
11591
}
11692

117-
protected override ICredentialsProvider? CredentialsProvider { get; }
118-
11993
private async Task<Status> DiscoverEndpoints()
12094
{
121-
using var channel = _grpcChannelFactory.CreateChannel(Config.Endpoint);
95+
using var channel = GrpcChannelFactory.CreateChannel(Config.Endpoint);
12296

12397
var client = new DiscoveryService.DiscoveryServiceClient(channel);
12498

@@ -167,7 +141,7 @@ private async Task<Status> DiscoverEndpoints()
167141
resultProto.Endpoints.Count, resultProto.SelfLocation, Config.SdkVersion
168142
);
169143

170-
await _channelPool.RemoveChannels(
144+
await ChannelPool.RemoveChannels(
171145
_endpointPool.Reset(resultProto.Endpoints
172146
.Select(endpointSettings => new EndpointSettings(
173147
(int)endpointSettings.NodeId,

src/Ydb.Sdk/src/DriverConfig.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ public DriverConfig(
5555
SdkVersion = $"ydb-dotnet-sdk/{versionStr}";
5656
}
5757

58+
internal Grpc.Core.Metadata GetCallMetadata => new()
59+
{
60+
{ Metadata.RpcDatabaseHeader, Database },
61+
{ Metadata.RpcSdkInfoHeader, SdkVersion }
62+
};
63+
5864
private static string FormatEndpoint(string endpoint)
5965
{
6066
endpoint = endpoint.ToLower().Trim();

src/Ydb.Sdk/src/IDriver.cs

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
using Grpc.Net.Client;
33
using Microsoft.Extensions.Logging;
44
using Ydb.Sdk.Auth;
5+
using Ydb.Sdk.Pool;
6+
using Ydb.Sdk.Services.Auth;
57

68
namespace Ydb.Sdk;
79

@@ -45,16 +47,35 @@ public interface IBidirectionalStream<in TRequest, out TResponse> : IDisposable
4547

4648
public abstract class BaseDriver : IDriver
4749
{
50+
private readonly ICredentialsProvider? _credentialsProvider;
51+
4852
protected readonly DriverConfig Config;
4953
protected readonly ILogger Logger;
54+
55+
internal readonly GrpcChannelFactory GrpcChannelFactory;
56+
internal readonly ChannelPool<GrpcChannel> ChannelPool;
5057

5158
protected int Disposed;
5259

53-
protected BaseDriver(DriverConfig config, ILoggerFactory loggerFactory, ILogger logger)
60+
internal BaseDriver(
61+
DriverConfig config,
62+
ILoggerFactory loggerFactory,
63+
ILogger logger
64+
)
5465
{
5566
Config = config;
5667
Logger = logger;
5768
LoggerFactory = loggerFactory;
69+
70+
GrpcChannelFactory = new GrpcChannelFactory(LoggerFactory, Config);
71+
ChannelPool = new ChannelPool<GrpcChannel>(LoggerFactory, GrpcChannelFactory);
72+
73+
_credentialsProvider = Config.User != null
74+
? new CachedCredentialsProvider(
75+
new StaticCredentialsAuthClient(Config, GrpcChannelFactory, LoggerFactory),
76+
LoggerFactory
77+
)
78+
: Config.Credentials;
5879
}
5980

6081
public async Task<TResponse> UnaryCall<TRequest, TResponse>(
@@ -64,7 +85,9 @@ public async Task<TResponse> UnaryCall<TRequest, TResponse>(
6485
where TRequest : class
6586
where TResponse : class
6687
{
67-
var (endpoint, channel) = GetChannel(settings.NodeId);
88+
var endpoint = GetEndpoint(settings.NodeId);
89+
var channel = ChannelPool.GetChannel(endpoint);
90+
6891
var callInvoker = channel.CreateCallInvoker();
6992

7093
Logger.LogTrace("Unary call, method: {MethodName}, endpoint: {Endpoint}", method.Name, endpoint);
@@ -97,7 +120,9 @@ public async ValueTask<ServerStream<TResponse>> ServerStreamCall<TRequest, TResp
97120
where TRequest : class
98121
where TResponse : class
99122
{
100-
var (endpoint, channel) = GetChannel(settings.NodeId);
123+
var endpoint = GetEndpoint(settings.NodeId);
124+
var channel = ChannelPool.GetChannel(endpoint);
125+
101126
var callInvoker = channel.CreateCallInvoker();
102127

103128
var call = callInvoker.AsyncServerStreamingCall(
@@ -115,7 +140,9 @@ public async ValueTask<IBidirectionalStream<TRequest, TResponse>> BidirectionalS
115140
where TRequest : class
116141
where TResponse : class
117142
{
118-
var (endpoint, channel) = GetChannel(settings.NodeId);
143+
var endpoint = GetEndpoint(settings.NodeId);
144+
var channel = ChannelPool.GetChannel(endpoint);
145+
119146
var callInvoker = channel.CreateCallInvoker();
120147

121148
var call = callInvoker.AsyncDuplexStreamingCall(
@@ -126,36 +153,29 @@ public async ValueTask<IBidirectionalStream<TRequest, TResponse>> BidirectionalS
126153
return new BidirectionalStream<TRequest, TResponse>(
127154
call,
128155
e => { OnRpcError(endpoint, e); },
129-
CredentialsProvider
156+
_credentialsProvider
130157
);
131158
}
132159

133-
protected abstract (string, GrpcChannel) GetChannel(long nodeId);
160+
protected abstract string GetEndpoint(long nodeId);
134161

135162
protected abstract void OnRpcError(string endpoint, RpcException e);
136163

137164
protected async ValueTask<CallOptions> GetCallOptions(GrpcRequestSettings settings)
138165
{
139-
var meta = new Grpc.Core.Metadata
140-
{
141-
{ Metadata.RpcDatabaseHeader, Config.Database },
142-
{ Metadata.RpcSdkInfoHeader, Config.SdkVersion }
143-
};
166+
var meta = Config.GetCallMetadata;
144167

145-
if (CredentialsProvider != null)
168+
if (_credentialsProvider != null)
146169
{
147-
meta.Add(Metadata.RpcAuthHeader, await CredentialsProvider.GetAuthInfoAsync());
170+
meta.Add(Metadata.RpcAuthHeader, await _credentialsProvider.GetAuthInfoAsync());
148171
}
149172

150173
if (settings.TraceId.Length > 0)
151174
{
152175
meta.Add(Metadata.RpcTraceIdHeader, settings.TraceId);
153176
}
154177

155-
var options = new CallOptions(
156-
headers: meta,
157-
cancellationToken: settings.CancellationToken
158-
);
178+
var options = new CallOptions(headers: meta, cancellationToken: settings.CancellationToken);
159179

160180
if (settings.TransportTimeout != TimeSpan.Zero)
161181
{
@@ -165,8 +185,6 @@ protected async ValueTask<CallOptions> GetCallOptions(GrpcRequestSettings settin
165185
return options;
166186
}
167187

168-
protected abstract ICredentialsProvider? CredentialsProvider { get; }
169-
170188
public ILoggerFactory LoggerFactory { get; }
171189

172190
public void Dispose() => DisposeAsync().AsTask().GetAwaiter().GetResult();
@@ -175,11 +193,11 @@ public async ValueTask DisposeAsync()
175193
{
176194
if (Interlocked.CompareExchange(ref Disposed, 1, 0) == 0)
177195
{
178-
await InternalDispose();
196+
await ChannelPool.DisposeAsync();
197+
198+
GC.SuppressFinalize(this);
179199
}
180200
}
181-
182-
protected abstract ValueTask InternalDispose();
183201
}
184202

185203
public sealed class ServerStream<TResponse> : IAsyncEnumerator<TResponse>, IAsyncEnumerable<TResponse>

src/Ydb.Sdk/src/Pool/ChannelPool.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ internal class ChannelPool<T> : IAsyncDisposable where T : ChannelBase, IDisposa
1515
private readonly ILogger<ChannelPool<T>> _logger;
1616
private readonly IChannelFactory<T> _channelFactory;
1717

18-
public ChannelPool(ILogger<ChannelPool<T>> logger, IChannelFactory<T> channelFactory)
18+
public ChannelPool(ILoggerFactory loggerFactory, IChannelFactory<T> channelFactory)
1919
{
20-
_logger = logger;
20+
_logger = loggerFactory.CreateLogger<ChannelPool<T>>();
2121
_channelFactory = channelFactory;
2222
}
2323

src/Ydb.Sdk/src/Pool/EndpointPool.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ internal class EndpointPool
1616
private Dictionary<long, string> _nodeIdToEndpoint = new();
1717
private int _preferredEndpointCount;
1818

19-
internal EndpointPool(ILogger<EndpointPool> logger, IRandom? random = null)
19+
internal EndpointPool(ILoggerFactory loggerFactory, IRandom? random = null)
2020
{
21-
_logger = logger;
21+
_logger = loggerFactory.CreateLogger<EndpointPool>();
2222
_random = random ?? ThreadLocalRandom.Instance;
2323
}
2424

src/Ydb.Sdk/src/Services/Auth/StaticCredentialsAuthClient.cs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
using System.IdentityModel.Tokens.Jwt;
2+
using Grpc.Core;
23
using Microsoft.Extensions.Logging;
34
using Ydb.Auth;
45
using Ydb.Auth.V1;
56
using Ydb.Sdk.Auth;
67
using Ydb.Sdk.Client;
78
using Ydb.Sdk.Pool;
89
using Ydb.Sdk.Services.Operations;
9-
using Ydb.Sdk.Transport;
1010

1111
namespace Ydb.Sdk.Services.Auth;
1212

1313
internal class StaticCredentialsAuthClient : IAuthClient
1414
{
1515
private readonly DriverConfig _config;
1616
private readonly GrpcChannelFactory _grpcChannelFactory;
17-
private readonly ILoggerFactory _loggerFactory;
1817
private readonly ILogger<StaticCredentialsAuthClient> _logger;
1918

2019
private readonly RetrySettings _retrySettings = new(5);
@@ -27,7 +26,6 @@ ILoggerFactory loggerFactory
2726
{
2827
_config = config;
2928
_grpcChannelFactory = grpcChannelFactory;
30-
_loggerFactory = loggerFactory;
3129
_logger = loggerFactory.CreateLogger<StaticCredentialsAuthClient>();
3230
}
3331

@@ -76,13 +74,10 @@ private async Task<LoginResponse> Login()
7674

7775
try
7876
{
79-
await using var transport = new DirectGrpcChannelDriver(_config, _grpcChannelFactory, _loggerFactory);
77+
using var channel = _grpcChannelFactory.CreateChannel(_config.Endpoint);
8078

81-
var response = await transport.UnaryCall(
82-
method: AuthService.LoginMethod,
83-
request: request,
84-
settings: new GrpcRequestSettings()
85-
);
79+
var response = await new AuthService.AuthServiceClient(channel)
80+
.LoginAsync(request, new CallOptions(_config.GetCallMetadata));
8681

8782
var status = response.Operation.TryUnpack(out LoginResult? resultProto);
8883

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,22 @@
11
using Grpc.Core;
2-
using Grpc.Net.Client;
32
using Microsoft.Extensions.Logging;
4-
using Ydb.Sdk.Auth;
5-
using Ydb.Sdk.Pool;
63

74
namespace Ydb.Sdk.Transport;
85

96
public class DirectGrpcChannelDriver : BaseDriver
107
{
11-
private readonly GrpcChannel _channel;
12-
13-
internal DirectGrpcChannelDriver(
14-
DriverConfig driverConfig,
15-
GrpcChannelFactory grpcChannelFactory,
16-
ILoggerFactory loggerFactory
17-
) : base(
18-
new DriverConfig(
19-
endpoint: driverConfig.Endpoint,
20-
database: driverConfig.Database,
21-
customServerCertificates: driverConfig.CustomServerCertificates
22-
), loggerFactory, loggerFactory.CreateLogger<DirectGrpcChannelDriver>())
23-
{
24-
_channel = grpcChannelFactory.CreateChannel(Config.Endpoint);
25-
}
26-
278
public DirectGrpcChannelDriver(DriverConfig driverConfig, ILoggerFactory loggerFactory) :
28-
this(driverConfig, new GrpcChannelFactory(loggerFactory, driverConfig), loggerFactory)
9+
base(driverConfig, loggerFactory, loggerFactory.CreateLogger<DirectGrpcChannelDriver>())
2910
{
3011
}
3112

32-
protected override (string, GrpcChannel) GetChannel(long nodeId) => (Config.Endpoint, _channel);
13+
protected override string GetEndpoint(long nodeId) => Config.Endpoint;
3314

3415
protected override void OnRpcError(string endpoint, RpcException e)
3516
{
3617
var status = e.Status;
37-
if (e.Status.StatusCode != Grpc.Core.StatusCode.OK)
38-
{
39-
Logger.LogWarning("gRPC error {StatusCode}[{Detail}] on fixed channel {Endpoint}",
40-
status.StatusCode, status.Detail, endpoint);
41-
}
42-
}
43-
44-
protected override ICredentialsProvider? CredentialsProvider => null;
45-
46-
protected override async ValueTask InternalDispose()
47-
{
48-
await _channel.ShutdownAsync();
4918

50-
_channel.Dispose();
19+
Logger.LogWarning("gRPC error {StatusCode}[{Detail}] on fixed channel {Endpoint}",
20+
status.StatusCode, status.Detail, endpoint);
5121
}
5222
}

0 commit comments

Comments
 (0)