Skip to content

Commit 2a98b0b

Browse files
type safe generic
1 parent 246edcb commit 2a98b0b

File tree

8 files changed

+65
-66
lines changed

8 files changed

+65
-66
lines changed

src/Ydb.Sdk/src/Ado/PoolManager.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace Ydb.Sdk.Ado;
77
internal static class PoolManager
88
{
99
private static readonly SemaphoreSlim SemaphoreSlim = new(1); // async mutex
10-
private static readonly ConcurrentDictionary<string, PoolingSessionSource> Pools = new();
10+
private static readonly ConcurrentDictionary<string, ISessionSource> Pools = new();
1111

1212
internal static async Task<ISession> GetSession(
1313
YdbConnectionStringBuilder settings,
@@ -28,7 +28,7 @@ CancellationToken cancellationToken
2828
return await pool.OpenSession(cancellationToken);
2929
}
3030

31-
var newSessionPool = new PoolingSessionSource(
31+
var newSessionPool = new PoolingSessionSource<PoolingSession>(
3232
new PoolingSessionFactory(
3333
await settings.BuildDriver(),
3434
settings,
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
namespace Ydb.Sdk.Ado.Session;
22

3-
internal interface ISessionSource<TSession> where TSession : ISession
3+
internal interface ISessionSource
44
{
5-
ValueTask<TSession> OpenSession(CancellationToken cancellationToken);
6-
7-
void Return(TSession session);
5+
ValueTask<ISession> OpenSession(CancellationToken cancellationToken);
86
}

src/Ydb.Sdk/src/Ado/Session/PoolingSession.cs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
// This file contains session pooling algorithms adapted from Npgsql
2-
// Original source: https://github.com/npgsql/npgsql
3-
// Copyright (c) 2002-2025, Npgsql
4-
// Licence https://github.com/npgsql/npgsql?tab=PostgreSQL-1-ov-file
5-
61
using Microsoft.Extensions.Logging;
72
using Ydb.Query;
83
using Ydb.Query.V1;
@@ -13,7 +8,7 @@
138

149
namespace Ydb.Sdk.Ado.Session;
1510

16-
internal class PoolingSession : PoolingSessionBase
11+
internal class PoolingSession : PoolingSessionBase<PoolingSession>
1712
{
1813
private const string SessionBalancer = "session-balancer";
1914

@@ -36,7 +31,7 @@ internal class PoolingSession : PoolingSessionBase
3631

3732
internal PoolingSession(
3833
IDriver driver,
39-
PoolingSessionSource poolingSessionSource,
34+
PoolingSessionSource<PoolingSession> poolingSessionSource,
4035
bool disableServerBalancer,
4136
ILogger<PoolingSession> logger
4237
) : base(poolingSessionSource)

src/Ydb.Sdk/src/Ado/Session/PoolingSessionFactory.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
namespace Ydb.Sdk.Ado.Session;
44

5-
internal class PoolingSessionFactory : IPoolingSessionFactory
5+
internal class PoolingSessionFactory : IPoolingSessionFactory<PoolingSession>
66
{
77
private readonly IDriver _driver;
88
private readonly bool _disableServerBalancer;
@@ -15,6 +15,6 @@ public PoolingSessionFactory(IDriver driver, YdbConnectionStringBuilder settings
1515
_logger = loggerFactory.CreateLogger<PoolingSession>();
1616
}
1717

18-
public PoolingSessionBase NewSession(PoolingSessionSource source) =>
19-
new PoolingSession(_driver, source, _disableServerBalancer, _logger);
18+
public PoolingSession NewSession(PoolingSessionSource<PoolingSession> source) =>
19+
new(_driver, source, _disableServerBalancer, _logger);
2020
}

src/Ydb.Sdk/src/Ado/Session/PoolingSessionSource.cs

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,23 @@
66

77
namespace Ydb.Sdk.Ado.Session;
88

9-
internal sealed class PoolingSessionSource : ISessionSource<PoolingSessionBase>
9+
internal sealed class PoolingSessionSource<T> : ISessionSource where T : PoolingSessionBase<T>
1010
{
11-
private readonly ConcurrentStack<PoolingSessionBase> _idleSessions = new();
12-
private readonly ConcurrentQueue<TaskCompletionSource<PoolingSessionBase?>> _waiters = new();
13-
14-
private readonly IPoolingSessionFactory _sessionFactory;
11+
private readonly ConcurrentStack<T> _idleSessions = new();
12+
private readonly ConcurrentQueue<TaskCompletionSource<T?>> _waiters = new();
1513

14+
private readonly IPoolingSessionFactory<T> _sessionFactory;
1615
private readonly int _minSessionSize;
1716
private readonly int _maxSessionSize;
18-
19-
private readonly PoolingSessionBase?[] _sessions;
20-
17+
private readonly T?[] _sessions;
2118
private readonly int _createSessionTimeout;
2219
private readonly TimeSpan _sessionIdleTimeout;
2320
private readonly Timer _cleanerTimer;
2421

2522
private volatile int _numSessions;
2623

2724
public PoolingSessionSource(
28-
IPoolingSessionFactory sessionFactory,
25+
IPoolingSessionFactory<T> sessionFactory,
2926
YdbConnectionStringBuilder settings
3027
)
3128
{
@@ -39,19 +36,19 @@ YdbConnectionStringBuilder settings
3936
$"Connection can't have 'Max Session Pool' {_maxSessionSize} under 'Min Session Pool' {_minSessionSize}");
4037
}
4138

42-
_sessions = new PoolingSessionBase?[_maxSessionSize];
39+
_sessions = new T?[_maxSessionSize];
4340
_createSessionTimeout = settings.CreateSessionTimeout;
4441
_sessionIdleTimeout = TimeSpan.FromSeconds(settings.SessionIdleTimeout);
4542
_cleanerTimer = new Timer(CleanIdleSessions, this, _sessionIdleTimeout, _sessionIdleTimeout);
4643
}
4744

48-
public ValueTask<PoolingSessionBase> OpenSession(CancellationToken cancellationToken = default) =>
45+
public ValueTask<ISession> OpenSession(CancellationToken cancellationToken = default) =>
4946
TryGetIdleSession(out var session)
50-
? new ValueTask<PoolingSessionBase>(session)
47+
? new ValueTask<ISession>(session)
5148
: RentAsync(cancellationToken);
5249

5350
[MethodImpl(MethodImplOptions.AggressiveInlining)]
54-
private bool TryGetIdleSession([NotNullWhen(true)] out PoolingSessionBase? session)
51+
private bool TryGetIdleSession([NotNullWhen(true)] out T? session)
5552
{
5653
while (_idleSessions.TryPop(out session))
5754
{
@@ -65,7 +62,7 @@ private bool TryGetIdleSession([NotNullWhen(true)] out PoolingSessionBase? sessi
6562
}
6663

6764
[MethodImpl(MethodImplOptions.AggressiveInlining)]
68-
private bool CheckIdleSession([NotNullWhen(true)] PoolingSessionBase? session)
65+
private bool CheckIdleSession([NotNullWhen(true)] T? session)
6966
{
7067
if (session == null || session.State == PoolingSessionState.Clean)
7168
{
@@ -82,7 +79,7 @@ private bool CheckIdleSession([NotNullWhen(true)] PoolingSessionBase? session)
8279
return session.CompareAndSet(PoolingSessionState.In, PoolingSessionState.Out);
8380
}
8481

85-
private async ValueTask<PoolingSessionBase> RentAsync(CancellationToken cancellationToken)
82+
private async ValueTask<ISession> RentAsync(CancellationToken cancellationToken)
8683
{
8784
using var ctsGetSession = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
8885
if (_createSessionTimeout > 0)
@@ -96,8 +93,7 @@ private async ValueTask<PoolingSessionBase> RentAsync(CancellationToken cancella
9693

9794
while (true)
9895
{
99-
var waiterTcs =
100-
new TaskCompletionSource<PoolingSessionBase?>(TaskCreationOptions.RunContinuationsAsynchronously);
96+
var waiterTcs = new TaskCompletionSource<T?>(TaskCreationOptions.RunContinuationsAsynchronously);
10197
_waiters.Enqueue(waiterTcs);
10298
await using var _ = finalToken.Register(() => waiterTcs.TrySetCanceled(), useSynchronizationContext: false);
10399
session = await waiterTcs.Task.ConfigureAwait(false);
@@ -111,9 +107,8 @@ private async ValueTask<PoolingSessionBase> RentAsync(CancellationToken cancella
111107
}
112108
}
113109

114-
private async ValueTask<PoolingSessionBase?> OpenNewSession(CancellationToken cancellationToken)
110+
private async ValueTask<T?> OpenNewSession(CancellationToken cancellationToken)
115111
{
116-
// As long as we're under max capacity, attempt to increase the session count and open a new session.
117112
for (var numSessions = _numSessions; numSessions < _maxSessionSize; numSessions = _numSessions)
118113
{
119114
if (Interlocked.CompareExchange(ref _numSessions, numSessions + 1, numSessions) != numSessions)
@@ -135,12 +130,8 @@ private async ValueTask<PoolingSessionBase> RentAsync(CancellationToken cancella
135130
}
136131
catch
137132
{
138-
// RPC open failed, decrement the open and busy counter back down.
139133
Interlocked.Decrement(ref _numSessions);
140134

141-
// In case there's a waiting attempt on the waiters queue, we write a null to the idle connector channel
142-
// to wake it up, so it will try opening (and probably throw immediately)
143-
// Statement order is important since we have synchronous completions on the channel.
144135
WakeUpWaiter();
145136

146137
throw;
@@ -156,7 +147,7 @@ private void WakeUpWaiter()
156147
waiter.TrySetResult(null); // wake up waiter!
157148
}
158149

159-
public void Return(PoolingSessionBase session)
150+
public void Return(T session)
160151
{
161152
if (session.IsBroken)
162153
{
@@ -181,7 +172,7 @@ public void Return(PoolingSessionBase session)
181172
WakeUpWaiter();
182173
}
183174

184-
private void CloseSession(PoolingSessionBase session)
175+
private void CloseSession(T session)
185176
{
186177
var i = 0;
187178
for (; i < _maxSessionSize; i++)
@@ -202,7 +193,7 @@ private void CloseSession(PoolingSessionBase session)
202193

203194
private static void CleanIdleSessions(object? state)
204195
{
205-
var pool = (PoolingSessionSource)state!;
196+
var pool = (PoolingSessionSource<T>)state!;
206197
var now = DateTime.Now;
207198

208199
for (var i = 0; i < pool._maxSessionSize; i++)
@@ -222,9 +213,9 @@ private static void CleanIdleSessions(object? state)
222213
}
223214
}
224215

225-
internal interface IPoolingSessionFactory
216+
internal interface IPoolingSessionFactory<T> where T : PoolingSessionBase<T>
226217
{
227-
PoolingSessionBase NewSession(PoolingSessionSource source);
218+
T NewSession(PoolingSessionSource<T> source);
228219
}
229220

230221
internal enum PoolingSessionState
@@ -234,13 +225,13 @@ internal enum PoolingSessionState
234225
Clean
235226
}
236227

237-
internal abstract class PoolingSessionBase : ISession
228+
internal abstract class PoolingSessionBase<T> : ISession where T : PoolingSessionBase<T>
238229
{
239-
private readonly PoolingSessionSource _source;
230+
private readonly PoolingSessionSource<T> _source;
240231

241232
private int _state = (int)PoolingSessionState.In;
242233

243-
protected PoolingSessionBase(PoolingSessionSource source)
234+
protected PoolingSessionBase(PoolingSessionSource<T> source)
244235
{
245236
_source = source;
246237
}
@@ -272,5 +263,5 @@ public abstract ValueTask<IServerStream<ExecuteQueryResponsePart>> ExecuteQuery(
272263

273264
public abstract void OnNotSuccessStatusCode(StatusCode code);
274265

275-
public void Close() => _source.Return(this);
266+
public void Close() => _source.Return((T)this);
276267
}

src/Ydb.Sdk/test/Ydb.Sdk.Ado.Benchmarks/SessionSourceBenchmark.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Ydb.Sdk.Ado.Benchmarks;
99
[ThreadingDiagnoser]
1010
public class SessionSourceBenchmark
1111
{
12-
private PoolingSessionSource _poolingSessionSource = null!;
12+
private PoolingSessionSource<MockPoolingSession> _poolingSessionSource = null!;
1313
private const int SessionPoolSize = 50;
1414
private const int ConcurrentTasks = 20;
1515

@@ -18,7 +18,7 @@ public void Setup()
1818
{
1919
var settings = new YdbConnectionStringBuilder { MaxSessionPool = SessionPoolSize };
2020

21-
_poolingSessionSource = new PoolingSessionSource(new MockSessionFactory(), settings);
21+
_poolingSessionSource = new PoolingSessionSource<MockPoolingSession>(new MockSessionFactory(), settings);
2222
}
2323

2424
[Benchmark]
@@ -85,12 +85,13 @@ public async Task SessionReuse_Pattern()
8585
}
8686
}
8787

88-
internal class MockSessionFactory : IPoolingSessionFactory
88+
internal class MockSessionFactory : IPoolingSessionFactory<MockPoolingSession>
8989
{
90-
public PoolingSessionBase NewSession(PoolingSessionSource source) => new MockPoolingSession(source);
90+
public MockPoolingSession NewSession(PoolingSessionSource<MockPoolingSession> source) => new(source);
9191
}
9292

93-
internal class MockPoolingSession(PoolingSessionSource source) : PoolingSessionBase(source)
93+
internal class MockPoolingSession(PoolingSessionSource<MockPoolingSession> source)
94+
: PoolingSessionBase<MockPoolingSession>(source)
9495
{
9596
public override IDriver Driver => null!;
9697
public override bool IsBroken => false;

src/Ydb.Sdk/test/Ydb.Sdk.Ado.Tests/Session/PoolingSessionSourceMockTests.cs

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,44 @@ public class PoolingSessionSourceMockTests
99
{
1010
[Fact]
1111
public void MinSessionPool_bigger_than_MaxSessionPool_throws() => Assert.Throws<ArgumentException>(() =>
12-
new PoolingSessionSource(new MockPoolingSessionFactory(),
12+
new PoolingSessionSource<MockPoolingSession>(new MockPoolingSessionFactory(),
1313
new YdbConnectionStringBuilder { MaxSessionPool = 1, MinSessionPool = 2 })
1414
);
1515

1616
[Fact]
1717
public async Task Reuse_Session_Before_Creating_new()
1818
{
19-
var sessionSource = new PoolingSessionSource(new MockPoolingSessionFactory(), new YdbConnectionStringBuilder());
20-
var session = (MockPoolingSession)await sessionSource.OpenSession();
21-
var sessionId = session.SessionId;
19+
var sessionSource =
20+
new PoolingSessionSource<MockPoolingSession>(new MockPoolingSessionFactory(),
21+
new YdbConnectionStringBuilder());
22+
var session = await sessionSource.OpenSession();
23+
var sessionId = session.SessionId();
2224
session.Close();
23-
session = (MockPoolingSession)await sessionSource.OpenSession();
24-
Assert.Equal(sessionId, session.SessionId);
25+
session = await sessionSource.OpenSession();
26+
Assert.Equal(sessionId, session.SessionId());
2527
}
28+
29+
[Fact]
30+
public async Task Creating_Session_Throw_Exception()
31+
{
32+
}
33+
}
34+
35+
internal static class ISessionExtension
36+
{
37+
internal static string SessionId(this ISession session) => ((MockPoolingSession)session).SessionId;
2638
}
2739

28-
internal class MockPoolingSessionFactory : IPoolingSessionFactory
40+
internal class MockPoolingSessionFactory : IPoolingSessionFactory<MockPoolingSession>
2941
{
3042
private int _sessionNum;
3143

32-
public PoolingSessionBase NewSession(PoolingSessionSource source) =>
33-
new MockPoolingSession(source, Interlocked.Increment(ref _sessionNum));
44+
public MockPoolingSession NewSession(PoolingSessionSource<MockPoolingSession> source) =>
45+
new(source, Interlocked.Increment(ref _sessionNum));
3446
}
3547

36-
internal class MockPoolingSession(PoolingSessionSource source, int sessionNum) : PoolingSessionBase(source)
48+
internal class MockPoolingSession(PoolingSessionSource<MockPoolingSession> source, int sessionNum)
49+
: PoolingSessionBase<MockPoolingSession>(source)
3750
{
3851
public string SessionId => $"session_{sessionNum}";
3952
public override IDriver Driver => null!;
@@ -44,7 +57,8 @@ internal class MockPoolingSession(PoolingSessionSource source, int sessionNum) :
4457

4558
public override ValueTask<IServerStream<ExecuteQueryResponsePart>> ExecuteQuery(
4659
string query,
47-
Dictionary<string, YdbValue> parameters, GrpcRequestSettings settings,
60+
Dictionary<string, YdbValue> parameters,
61+
GrpcRequestSettings settings,
4862
TransactionControl? txControl
4963
) => throw new NotImplementedException();
5064

src/Ydb.Sdk/test/Ydb.Sdk.Ado.Tests/Session/PoolingSessionTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public class PoolingSessionTests
1717
private readonly Mock<IDriver> _mockIDriver;
1818
private readonly Mock<IServerStream<SessionState>> _mockAttachStream = new(MockBehavior.Strict);
1919
private readonly PoolingSessionFactory _poolingSessionFactory;
20-
private readonly PoolingSessionSource _poolingSessionSource;
20+
private readonly PoolingSessionSource<PoolingSession> _poolingSessionSource;
2121

2222
public PoolingSessionTests()
2323
{
@@ -32,7 +32,7 @@ public PoolingSessionTests()
3232
).ReturnsAsync(_mockAttachStream.Object);
3333
_mockAttachStream.Setup(stream => stream.Dispose());
3434
_poolingSessionFactory = new PoolingSessionFactory(_mockIDriver.Object, settings, TestUtils.LoggerFactory);
35-
_poolingSessionSource = new PoolingSessionSource(_poolingSessionFactory, settings);
35+
_poolingSessionSource = new PoolingSessionSource<PoolingSession>(_poolingSessionFactory, settings);
3636
}
3737

3838
[Theory]

0 commit comments

Comments
 (0)