Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Ydb.Sdk/src/Client/Response.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ public abstract class StreamResponse<TProtoResponse, TResponse>
where TProtoResponse : class
where TResponse : class
{
private readonly Driver.StreamIterator<TProtoResponse> _iterator;
private readonly Driver.ServerStream<TProtoResponse> _iterator;
private TResponse? _response;
private bool _transportError;

internal StreamResponse(Driver.StreamIterator<TProtoResponse> iterator)
internal StreamResponse(Driver.ServerStream<TProtoResponse> iterator)
{
_iterator = iterator;
}
Expand Down
98 changes: 81 additions & 17 deletions src/Ydb.Sdk/src/Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ internal async Task<TResponse> UnaryCall<TRequest, TResponse>(
}
}

internal StreamIterator<TResponse> StreamCall<TRequest, TResponse>(
internal ServerStream<TResponse> ServerStreamCall<TRequest, TResponse>(
Method<TRequest, TResponse> method,
TRequest request,
GrpcRequestSettings settings)
Expand All @@ -160,22 +160,30 @@ internal StreamIterator<TResponse> StreamCall<TRequest, TResponse>(
var (endpoint, channel) = GetChannel(settings.NodeId);
var callInvoker = channel.CreateCallInvoker();

try
{
var call = callInvoker.AsyncServerStreamingCall(
method: method,
host: null,
options: GetCallOptions(settings, true),
request: request);
var call = callInvoker.AsyncServerStreamingCall(
method: method,
host: null,
options: GetCallOptions(settings, true),
request: request);

return new StreamIterator<TResponse>(call, () => { PessimizeEndpoint(endpoint); });
}
catch (RpcException e)
{
PessimizeEndpoint(endpoint);
return new ServerStream<TResponse>(call, () => { PessimizeEndpoint(endpoint); });
}

throw new TransportException(e);
}
internal BidirectionalStream<TRequest, TResponse> BidirectionalStreamCall<TRequest, TResponse>(
Method<TRequest, TResponse> method,
GrpcRequestSettings settings)
where TRequest : class
where TResponse : class
{
var (endpoint, channel) = GetChannel(settings.NodeId);
var callInvoker = channel.CreateCallInvoker();

var call = callInvoker.AsyncDuplexStreamingCall(
method: method,
host: null,
options: GetCallOptions(settings, true));

return new BidirectionalStream<TRequest, TResponse>(call, () => { PessimizeEndpoint(endpoint); });
}

private (string, GrpcChannel) GetChannel(long nodeId)
Expand Down Expand Up @@ -319,12 +327,12 @@ private CallOptions GetCallOptions(GrpcRequestSettings settings, bool streaming)
return options;
}

internal sealed class StreamIterator<TResponse> : IAsyncEnumerator<TResponse>, IAsyncEnumerable<TResponse>
internal sealed class ServerStream<TResponse> : IAsyncEnumerator<TResponse>, IAsyncEnumerable<TResponse>
{
private readonly AsyncServerStreamingCall<TResponse> _responseStream;
private readonly Action _rpcErrorAction;

internal StreamIterator(AsyncServerStreamingCall<TResponse> responseStream, Action rpcErrorAction)
internal ServerStream(AsyncServerStreamingCall<TResponse> responseStream, Action rpcErrorAction)
{
_responseStream = responseStream;
_rpcErrorAction = rpcErrorAction;
Expand Down Expand Up @@ -359,6 +367,62 @@ public async ValueTask<bool> MoveNextAsync()
}
}

internal sealed class BidirectionalStream<TRequest, TResponse> : IAsyncEnumerator<TResponse>,
IAsyncEnumerable<TResponse>
{
private readonly AsyncDuplexStreamingCall<TRequest, TResponse> _bidirectionalStream;
private readonly Action _rpcErrorAction;

public BidirectionalStream(AsyncDuplexStreamingCall<TRequest, TResponse> bidirectionalStream,
Action rpcErrorAction)
{
_bidirectionalStream = bidirectionalStream;
_rpcErrorAction = rpcErrorAction;
}

public async Task Write(TRequest request)
{
try
{
await _bidirectionalStream.RequestStream.WriteAsync(request);
}
catch (RpcException e)
{
_rpcErrorAction();

throw new TransportException(e);
}
}

public ValueTask DisposeAsync()
{
_bidirectionalStream.Dispose();

return default;
}

public async ValueTask<bool> MoveNextAsync()
{
try
{
return await _bidirectionalStream.ResponseStream.MoveNext(CancellationToken.None);
}
catch (RpcException e)
{
_rpcErrorAction();

throw new TransportException(e);
}
}

public TResponse Current => _bidirectionalStream.ResponseStream.Current;

public IAsyncEnumerator<TResponse> GetAsyncEnumerator(CancellationToken cancellationToken = new())
{
return this;
}
}

public class InitializationFailureException : Exception
{
internal InitializationFailureException(string message) : base(message)
Expand Down
8 changes: 4 additions & 4 deletions src/Ydb.Sdk/src/Services/Query/SessionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ protected override async Task<Session> CreateSession()
{
try
{
await using var stream = _driver.StreamCall(QueryService.AttachSessionMethod, new AttachSessionRequest
{ SessionId = session.SessionId }, AttachSessionSettings);
await using var stream = _driver.ServerStreamCall(QueryService.AttachSessionMethod,
new AttachSessionRequest { SessionId = session.SessionId }, AttachSessionSettings);

if (!await stream.MoveNextAsync())
{
Expand Down Expand Up @@ -140,7 +140,7 @@ internal Session(Driver driver, SessionPool<Session> sessionPool, string session
_driver = driver;
}

internal Driver.StreamIterator<ExecuteQueryResponsePart> ExecuteQuery(
internal Driver.ServerStream<ExecuteQueryResponsePart> ExecuteQuery(
string query,
Dictionary<string, YdbValue>? parameters,
ExecuteQuerySettings? settings,
Expand All @@ -161,7 +161,7 @@ internal Driver.StreamIterator<ExecuteQueryResponsePart> ExecuteQuery(

request.Parameters.Add(parameters.ToDictionary(p => p.Key, p => p.Value.GetProto()));

return _driver.StreamCall(QueryService.ExecuteQueryMethod, request, settings);
return _driver.ServerStreamCall(QueryService.ExecuteQueryMethod, request, settings);
}

internal async Task<Status> CommitTransaction(string txId, GrpcRequestSettings? settings = null)
Expand Down
4 changes: 2 additions & 2 deletions src/Ydb.Sdk/src/Services/Table/ExecuteScanQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ internal static ResultData FromProto(ExecuteScanQueryPartialResult resultProto)

public class ExecuteScanQueryStream : StreamResponse<ExecuteScanQueryPartialResponse, ExecuteScanQueryPart>
{
internal ExecuteScanQueryStream(Driver.StreamIterator<ExecuteScanQueryPartialResponse> iterator)
internal ExecuteScanQueryStream(Driver.ServerStream<ExecuteScanQueryPartialResponse> iterator)
: base(iterator)
{
}
Expand Down Expand Up @@ -75,7 +75,7 @@ public ExecuteScanQueryStream ExecuteScanQuery(

request.Parameters.Add(parameters.ToDictionary(p => p.Key, p => p.Value.GetProto()));

var streamIterator = _driver.StreamCall(
var streamIterator = _driver.ServerStreamCall(
method: TableService.StreamExecuteScanQueryMethod,
request: request,
settings: settings
Expand Down
4 changes: 2 additions & 2 deletions src/Ydb.Sdk/src/Services/Table/ReadTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ internal static ResultData FromProto(ReadTableResult resultProto)

public class ReadTableStream : StreamResponse<ReadTableResponse, ReadTablePart>
{
internal ReadTableStream(Driver.StreamIterator<ReadTableResponse> iterator)
internal ReadTableStream(Driver.ServerStream<ReadTableResponse> iterator)
: base(iterator)
{
}
Expand Down Expand Up @@ -74,7 +74,7 @@ public ReadTableStream ReadTable(string tablePath, ReadTableSettings? settings =
Ordered = settings.Ordered
};

var streamIterator = _driver.StreamCall(
var streamIterator = _driver.ServerStreamCall(
method: TableService.StreamReadTableMethod,
request: request,
settings: settings
Expand Down
Loading