Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Ydb.Sdk/src/Client/Response.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ public abstract class StreamResponse<TProtoResponse, TResponse>
where TProtoResponse : class
where TResponse : class
{
private readonly Driver.StreamIterator<TProtoResponse> _iterator;
private readonly Driver.ServerStream<TProtoResponse> _iterator;
private TResponse? _response;
private bool _transportError;

internal StreamResponse(Driver.StreamIterator<TProtoResponse> iterator)
internal StreamResponse(Driver.ServerStream<TProtoResponse> iterator)
{
_iterator = iterator;
}
Expand Down
98 changes: 81 additions & 17 deletions src/Ydb.Sdk/src/Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ internal async Task<TResponse> UnaryCall<TRequest, TResponse>(
}
}

internal StreamIterator<TResponse> StreamCall<TRequest, TResponse>(
internal ServerStream<TResponse> ServerStreamCall<TRequest, TResponse>(
Method<TRequest, TResponse> method,
TRequest request,
GrpcRequestSettings settings)
Expand All @@ -160,22 +160,30 @@ internal StreamIterator<TResponse> StreamCall<TRequest, TResponse>(
var (endpoint, channel) = GetChannel(settings.NodeId);
var callInvoker = channel.CreateCallInvoker();

try
{
var call = callInvoker.AsyncServerStreamingCall(
method: method,
host: null,
options: GetCallOptions(settings, true),
request: request);
var call = callInvoker.AsyncServerStreamingCall(
method: method,
host: null,
options: GetCallOptions(settings, true),
request: request);

return new StreamIterator<TResponse>(call, () => { PessimizeEndpoint(endpoint); });
}
catch (RpcException e)
{
PessimizeEndpoint(endpoint);
return new ServerStream<TResponse>(call, () => { PessimizeEndpoint(endpoint); });
}

throw new TransportException(e);
}
internal BidirectionalStream<TRequest, TResponse> BidirectionalStreamCall<TRequest, TResponse>(
Method<TRequest, TResponse> method,
GrpcRequestSettings settings)
where TRequest : class
where TResponse : class
{
var (endpoint, channel) = GetChannel(settings.NodeId);
var callInvoker = channel.CreateCallInvoker();

var call = callInvoker.AsyncDuplexStreamingCall(
method: method,
host: null,
options: GetCallOptions(settings, true));

return new BidirectionalStream<TRequest, TResponse>(call, () => { PessimizeEndpoint(endpoint); });
}

private (string, GrpcChannel) GetChannel(long nodeId)
Expand Down Expand Up @@ -319,12 +327,12 @@ private CallOptions GetCallOptions(GrpcRequestSettings settings, bool streaming)
return options;
}

internal sealed class StreamIterator<TResponse> : IAsyncEnumerator<TResponse>, IAsyncEnumerable<TResponse>
internal sealed class ServerStream<TResponse> : IAsyncEnumerator<TResponse>, IAsyncEnumerable<TResponse>
{
private readonly AsyncServerStreamingCall<TResponse> _responseStream;
private readonly Action _rpcErrorAction;

internal StreamIterator(AsyncServerStreamingCall<TResponse> responseStream, Action rpcErrorAction)
internal ServerStream(AsyncServerStreamingCall<TResponse> responseStream, Action rpcErrorAction)
{
_responseStream = responseStream;
_rpcErrorAction = rpcErrorAction;
Expand Down Expand Up @@ -359,6 +367,62 @@ public async ValueTask<bool> MoveNextAsync()
}
}

internal sealed class BidirectionalStream<TRequest, TResponse> : IAsyncEnumerator<TResponse>,
IAsyncEnumerable<TResponse>
{
private readonly AsyncDuplexStreamingCall<TRequest, TResponse> _bidirectionalStream;
private readonly Action _rpcErrorAction;

public BidirectionalStream(AsyncDuplexStreamingCall<TRequest, TResponse> bidirectionalStream,
Action rpcErrorAction)
{
_bidirectionalStream = bidirectionalStream;
_rpcErrorAction = rpcErrorAction;
}

public async Task Write(TRequest request)
{
try
{
await _bidirectionalStream.RequestStream.WriteAsync(request);
}
catch (RpcException e)
{
_rpcErrorAction();

throw new TransportException(e);
}
}

public ValueTask DisposeAsync()
{
_bidirectionalStream.Dispose();

return default;
}

public async ValueTask<bool> MoveNextAsync()
{
try
{
return await _bidirectionalStream.ResponseStream.MoveNext(CancellationToken.None);
}
catch (RpcException e)
{
_rpcErrorAction();

throw new TransportException(e);
}
}

public TResponse Current => _bidirectionalStream.ResponseStream.Current;

public IAsyncEnumerator<TResponse> GetAsyncEnumerator(CancellationToken cancellationToken = new())
{
return this;
}
}

public class InitializationFailureException : Exception
{
internal InitializationFailureException(string message) : base(message)
Expand Down
6 changes: 3 additions & 3 deletions src/Ydb.Sdk/src/Services/Query/SessionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ protected override async Task<Session> CreateSession()
{
try
{
await using var stream = _driver.StreamCall(QueryService.AttachSessionMethod, new AttachSessionRequest
await using var stream = _driver.ServerStreamCall(QueryService.AttachSessionMethod, new AttachSessionRequest
{ SessionId = session.SessionId }, AttachSessionSettings);

if (!await stream.MoveNextAsync())
Expand Down Expand Up @@ -140,7 +140,7 @@ internal Session(Driver driver, SessionPool<Session> sessionPool, string session
_driver = driver;
}

internal Driver.StreamIterator<ExecuteQueryResponsePart> ExecuteQuery(
internal Driver.ServerStream<ExecuteQueryResponsePart> ExecuteQuery(
string query,
Dictionary<string, YdbValue>? parameters,
ExecuteQuerySettings? settings,
Expand All @@ -161,7 +161,7 @@ internal Driver.StreamIterator<ExecuteQueryResponsePart> ExecuteQuery(

request.Parameters.Add(parameters.ToDictionary(p => p.Key, p => p.Value.GetProto()));

return _driver.StreamCall(QueryService.ExecuteQueryMethod, request, settings);
return _driver.ServerStreamCall(QueryService.ExecuteQueryMethod, request, settings);
}

internal async Task<Status> CommitTransaction(string txId, GrpcRequestSettings? settings = null)
Expand Down
4 changes: 2 additions & 2 deletions src/Ydb.Sdk/src/Services/Table/ExecuteScanQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ internal static ResultData FromProto(ExecuteScanQueryPartialResult resultProto)

public class ExecuteScanQueryStream : StreamResponse<ExecuteScanQueryPartialResponse, ExecuteScanQueryPart>
{
internal ExecuteScanQueryStream(Driver.StreamIterator<ExecuteScanQueryPartialResponse> iterator)
internal ExecuteScanQueryStream(Driver.ServerStream<ExecuteScanQueryPartialResponse> iterator)
: base(iterator)
{
}
Expand Down Expand Up @@ -75,7 +75,7 @@ public ExecuteScanQueryStream ExecuteScanQuery(

request.Parameters.Add(parameters.ToDictionary(p => p.Key, p => p.Value.GetProto()));

var streamIterator = _driver.StreamCall(
var streamIterator = _driver.ServerStreamCall(
method: TableService.StreamExecuteScanQueryMethod,
request: request,
settings: settings
Expand Down
4 changes: 2 additions & 2 deletions src/Ydb.Sdk/src/Services/Table/ReadTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ internal static ResultData FromProto(ReadTableResult resultProto)

public class ReadTableStream : StreamResponse<ReadTableResponse, ReadTablePart>
{
internal ReadTableStream(Driver.StreamIterator<ReadTableResponse> iterator)
internal ReadTableStream(Driver.ServerStream<ReadTableResponse> iterator)
: base(iterator)
{
}
Expand Down Expand Up @@ -74,7 +74,7 @@ public ReadTableStream ReadTable(string tablePath, ReadTableSettings? settings =
Ordered = settings.Ordered
};

var streamIterator = _driver.StreamCall(
var streamIterator = _driver.ServerStreamCall(
method: TableService.StreamReadTableMethod,
request: request,
settings: settings
Expand Down
Loading