diff --git a/src/Ydb.Sdk/src/Client/Response.cs b/src/Ydb.Sdk/src/Client/Response.cs index fd7b787b..c71cbf7d 100644 --- a/src/Ydb.Sdk/src/Client/Response.cs +++ b/src/Ydb.Sdk/src/Client/Response.cs @@ -63,11 +63,11 @@ public abstract class StreamResponse where TProtoResponse : class where TResponse : class { - private readonly Driver.StreamIterator _iterator; + private readonly Driver.ServerStream _iterator; private TResponse? _response; private bool _transportError; - internal StreamResponse(Driver.StreamIterator iterator) + internal StreamResponse(Driver.ServerStream iterator) { _iterator = iterator; } diff --git a/src/Ydb.Sdk/src/Driver.cs b/src/Ydb.Sdk/src/Driver.cs index d15d6b1e..b9dbc604 100644 --- a/src/Ydb.Sdk/src/Driver.cs +++ b/src/Ydb.Sdk/src/Driver.cs @@ -150,7 +150,7 @@ internal async Task UnaryCall( } } - internal StreamIterator StreamCall( + internal ServerStream ServerStreamCall( Method method, TRequest request, GrpcRequestSettings settings) @@ -160,22 +160,30 @@ internal StreamIterator StreamCall( var (endpoint, channel) = GetChannel(settings.NodeId); var callInvoker = channel.CreateCallInvoker(); - try - { - var call = callInvoker.AsyncServerStreamingCall( - method: method, - host: null, - options: GetCallOptions(settings, true), - request: request); + var call = callInvoker.AsyncServerStreamingCall( + method: method, + host: null, + options: GetCallOptions(settings, true), + request: request); - return new StreamIterator(call, () => { PessimizeEndpoint(endpoint); }); - } - catch (RpcException e) - { - PessimizeEndpoint(endpoint); + return new ServerStream(call, () => { PessimizeEndpoint(endpoint); }); + } - throw new TransportException(e); - } + internal BidirectionalStream BidirectionalStreamCall( + Method method, + GrpcRequestSettings settings) + where TRequest : class + where TResponse : class + { + var (endpoint, channel) = GetChannel(settings.NodeId); + var callInvoker = channel.CreateCallInvoker(); + + var call = callInvoker.AsyncDuplexStreamingCall( + method: method, + host: null, + options: GetCallOptions(settings, true)); + + return new BidirectionalStream(call, () => { PessimizeEndpoint(endpoint); }); } private (string, GrpcChannel) GetChannel(long nodeId) @@ -319,12 +327,12 @@ private CallOptions GetCallOptions(GrpcRequestSettings settings, bool streaming) return options; } - internal sealed class StreamIterator : IAsyncEnumerator, IAsyncEnumerable + internal sealed class ServerStream : IAsyncEnumerator, IAsyncEnumerable { private readonly AsyncServerStreamingCall _responseStream; private readonly Action _rpcErrorAction; - internal StreamIterator(AsyncServerStreamingCall responseStream, Action rpcErrorAction) + internal ServerStream(AsyncServerStreamingCall responseStream, Action rpcErrorAction) { _responseStream = responseStream; _rpcErrorAction = rpcErrorAction; @@ -359,6 +367,62 @@ public async ValueTask MoveNextAsync() } } + internal sealed class BidirectionalStream : IAsyncEnumerator, + IAsyncEnumerable + { + private readonly AsyncDuplexStreamingCall _bidirectionalStream; + private readonly Action _rpcErrorAction; + + public BidirectionalStream(AsyncDuplexStreamingCall bidirectionalStream, + Action rpcErrorAction) + { + _bidirectionalStream = bidirectionalStream; + _rpcErrorAction = rpcErrorAction; + } + + public async Task Write(TRequest request) + { + try + { + await _bidirectionalStream.RequestStream.WriteAsync(request); + } + catch (RpcException e) + { + _rpcErrorAction(); + + throw new TransportException(e); + } + } + + public ValueTask DisposeAsync() + { + _bidirectionalStream.Dispose(); + + return default; + } + + public async ValueTask MoveNextAsync() + { + try + { + return await _bidirectionalStream.ResponseStream.MoveNext(CancellationToken.None); + } + catch (RpcException e) + { + _rpcErrorAction(); + + throw new TransportException(e); + } + } + + public TResponse Current => _bidirectionalStream.ResponseStream.Current; + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = new()) + { + return this; + } + } + public class InitializationFailureException : Exception { internal InitializationFailureException(string message) : base(message) diff --git a/src/Ydb.Sdk/src/Services/Query/SessionPool.cs b/src/Ydb.Sdk/src/Services/Query/SessionPool.cs index 8615a4e8..10609fab 100644 --- a/src/Ydb.Sdk/src/Services/Query/SessionPool.cs +++ b/src/Ydb.Sdk/src/Services/Query/SessionPool.cs @@ -47,8 +47,8 @@ protected override async Task CreateSession() { try { - await using var stream = _driver.StreamCall(QueryService.AttachSessionMethod, new AttachSessionRequest - { SessionId = session.SessionId }, AttachSessionSettings); + await using var stream = _driver.ServerStreamCall(QueryService.AttachSessionMethod, + new AttachSessionRequest { SessionId = session.SessionId }, AttachSessionSettings); if (!await stream.MoveNextAsync()) { @@ -140,7 +140,7 @@ internal Session(Driver driver, SessionPool sessionPool, string session _driver = driver; } - internal Driver.StreamIterator ExecuteQuery( + internal Driver.ServerStream ExecuteQuery( string query, Dictionary? parameters, ExecuteQuerySettings? settings, @@ -161,7 +161,7 @@ internal Driver.StreamIterator ExecuteQuery( request.Parameters.Add(parameters.ToDictionary(p => p.Key, p => p.Value.GetProto())); - return _driver.StreamCall(QueryService.ExecuteQueryMethod, request, settings); + return _driver.ServerStreamCall(QueryService.ExecuteQueryMethod, request, settings); } internal async Task CommitTransaction(string txId, GrpcRequestSettings? settings = null) diff --git a/src/Ydb.Sdk/src/Services/Table/ExecuteScanQuery.cs b/src/Ydb.Sdk/src/Services/Table/ExecuteScanQuery.cs index 8a9a6046..f23cb909 100644 --- a/src/Ydb.Sdk/src/Services/Table/ExecuteScanQuery.cs +++ b/src/Ydb.Sdk/src/Services/Table/ExecuteScanQuery.cs @@ -34,7 +34,7 @@ internal static ResultData FromProto(ExecuteScanQueryPartialResult resultProto) public class ExecuteScanQueryStream : StreamResponse { - internal ExecuteScanQueryStream(Driver.StreamIterator iterator) + internal ExecuteScanQueryStream(Driver.ServerStream iterator) : base(iterator) { } @@ -75,7 +75,7 @@ public ExecuteScanQueryStream ExecuteScanQuery( request.Parameters.Add(parameters.ToDictionary(p => p.Key, p => p.Value.GetProto())); - var streamIterator = _driver.StreamCall( + var streamIterator = _driver.ServerStreamCall( method: TableService.StreamExecuteScanQueryMethod, request: request, settings: settings diff --git a/src/Ydb.Sdk/src/Services/Table/ReadTable.cs b/src/Ydb.Sdk/src/Services/Table/ReadTable.cs index 5e87007f..879e8edd 100644 --- a/src/Ydb.Sdk/src/Services/Table/ReadTable.cs +++ b/src/Ydb.Sdk/src/Services/Table/ReadTable.cs @@ -39,7 +39,7 @@ internal static ResultData FromProto(ReadTableResult resultProto) public class ReadTableStream : StreamResponse { - internal ReadTableStream(Driver.StreamIterator iterator) + internal ReadTableStream(Driver.ServerStream iterator) : base(iterator) { } @@ -74,7 +74,7 @@ public ReadTableStream ReadTable(string tablePath, ReadTableSettings? settings = Ordered = settings.Ordered }; - var streamIterator = _driver.StreamCall( + var streamIterator = _driver.ServerStreamCall( method: TableService.StreamReadTableMethod, request: request, settings: settings