Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
- Topic Reader & Writer: update auth token in bidirectional stream.

## v0.14.1
- Fixed bug: public key presented not for certificate signature.
- Fixed: YdbDataReader does not throw YdbException when CloseAsync is called for UPDATE/INSERT statements with no
Expand Down
15 changes: 13 additions & 2 deletions src/Ydb.Sdk/src/IDriver.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Grpc.Core;
using Grpc.Net.Client;
using Microsoft.Extensions.Logging;
using Ydb.Sdk.Auth;

namespace Ydb.Sdk;

Expand Down Expand Up @@ -36,6 +37,8 @@ public interface IBidirectionalStream<in TRequest, out TResponse> : IDisposable
public ValueTask<bool> MoveNextAsync();

public TResponse Current { get; }

public string? AuthToken { get; }
}

public abstract class BaseDriver : IDriver
Expand Down Expand Up @@ -118,7 +121,10 @@ public IBidirectionalStream<TRequest, TResponse> BidirectionalStreamCall<TReques
host: null,
options: GetCallOptions(settings));

return new BidirectionalStream<TRequest, TResponse>(call, e => { OnRpcError(endpoint, e); });
return new BidirectionalStream<TRequest, TResponse>(
call,
e => { OnRpcError(endpoint, e); },
Config.Credentials);
}

protected abstract (string, GrpcChannel) GetChannel(long nodeId);
Expand Down Expand Up @@ -218,13 +224,16 @@ internal class BidirectionalStream<TRequest, TResponse> : IBidirectionalStream<T
{
private readonly AsyncDuplexStreamingCall<TRequest, TResponse> _stream;
private readonly Action<RpcException> _rpcErrorAction;
private readonly ICredentialsProvider _credentialsProvider;

internal BidirectionalStream(
AsyncDuplexStreamingCall<TRequest, TResponse> stream,
Action<RpcException> rpcErrorAction)
Action<RpcException> rpcErrorAction,
ICredentialsProvider credentialsProvider)
{
_stream = stream;
_rpcErrorAction = rpcErrorAction;
_credentialsProvider = credentialsProvider;
}

public async Task Write(TRequest request)
Expand Down Expand Up @@ -257,6 +266,8 @@ public async ValueTask<bool> MoveNextAsync()

public TResponse Current => _stream.ResponseStream.Current;

public string? AuthToken => _credentialsProvider.GetAuthInfo();

public void Dispose()
{
_stream.Dispose();
Expand Down
13 changes: 12 additions & 1 deletion src/Ydb.Sdk/src/Services/Topic/Reader/Reader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ public async void RunProcessingTopic()
{
await foreach (var messageFromClient in _channelFromClientMessageSending.Reader.ReadAllAsync())
{
await Stream.Write(messageFromClient);
await SendMessage(messageFromClient);
}
}
catch (Driver.TransportException e)
Expand Down Expand Up @@ -539,4 +539,15 @@ await _channelWriter.WriteAsync(
}
}
}

protected override MessageFromClient GetSendUpdateTokenRequest(string token)
{
return new MessageFromClient
{
UpdateTokenRequest = new UpdateTokenRequest
{
Token = token
}
};
}
}
20 changes: 20 additions & 0 deletions src/Ydb.Sdk/src/Services/Topic/TopicSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ internal abstract class TopicSession<TFromClient, TFromServer> : IDisposable
protected readonly string SessionId;

private int _isActive = 1;
private string? _lastToken;

protected TopicSession(
IBidirectionalStream<TFromClient, TFromServer> stream,
Expand All @@ -22,6 +23,7 @@ protected TopicSession(
Logger = logger;
SessionId = sessionId;
_initialize = initialize;
_lastToken = stream.AuthToken;
}

public bool IsActive => Volatile.Read(ref _isActive) == 1;
Expand All @@ -40,8 +42,26 @@ protected async void ReconnectSession()
await _initialize();
}

protected async Task SendMessage(TFromClient fromClient)
{
var curAuthToken = Stream.AuthToken;

if (!string.Equals(_lastToken, curAuthToken) && curAuthToken != null)
{
var updateTokenRequest = GetSendUpdateTokenRequest(curAuthToken);

_lastToken = curAuthToken;

await Stream.Write(updateTokenRequest);
}

await Stream.Write(fromClient);
}

public void Dispose()
{
Stream.Dispose();
}

protected abstract TFromClient GetSendUpdateTokenRequest(string token);
}
15 changes: 13 additions & 2 deletions src/Ydb.Sdk/src/Services/Topic/Writer/Writer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ private async Task Initialize()
}
catch (OperationCanceledException)
{
_logger.LogWarning("Initialize writer is canceled because it has been disposed");
_logger.LogInformation("Initialize writer is canceled because it has been disposed");
}
}

Expand Down Expand Up @@ -449,7 +449,7 @@ public async Task Write(ConcurrentQueue<MessageSending> toSendBuffer)
}

Volatile.Write(ref _seqNum, currentSeqNum);
await Stream.Write(new MessageFromClient { WriteRequest = writeMessage });
await SendMessage(new MessageFromClient { WriteRequest = writeMessage });
}
catch (Driver.TransportException e)
{
Expand Down Expand Up @@ -531,4 +531,15 @@ Completing task on exception...
ReconnectSession();
}
}

protected override MessageFromClient GetSendUpdateTokenRequest(string token)
{
return new MessageFromClient
{
UpdateTokenRequest = new UpdateTokenRequest
{
Token = token
}
};
}
}
101 changes: 101 additions & 0 deletions src/Ydb.Sdk/tests/Topic/ReaderUnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1395,6 +1395,107 @@ public async Task ReadAsync_WhenFailDeserializer_ThrowReaderExceptionAndInvokeRe
(await Assert.ThrowsAsync<ReaderException>(() => reader.ReadAsync().AsTask())).Message);
}

/*
*
Performed invocations:

Mock<IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>:1> (stream):

IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.Write({ "initRequest": { "topicsReadSettings": [ { "path": "/topic" } ], "consumer": "Consumer Tester" } })
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.MoveNextAsync()
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.Current
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.Write({ "readRequest": { "bytesSize": "1000" } })
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.AuthToken
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.MoveNextAsync()
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.Current
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.MoveNextAsync()
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.AuthToken
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.Write({ "startPartitionSessionResponse": { "partitionSessionId": "1" } })
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.Current
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.MoveNextAsync()
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.AuthToken
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.Write({ "updateTokenRequest": { "token": "Token2" } })
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.Write({ "commitOffsetRequest": { "commitOffsets": [ { "partitionSessionId": "1", "offsets": [ { "end": "1" } ] } ] } })
IBidirectionalStream<StreamReadMessage.Types.FromClient, StreamReadMessage.Types.FromServer>.Current
*/
[Fact]
public async Task ReadAsync_WhenTokenIsUpdatedOneTime_SuccessUpdateToken()
{
_mockStream.SetupSequence(stream => stream.AuthToken)
.Returns("Token1")
.Returns("Token1")
.Returns("Token2")
.Returns("Token2");

var tcsMoveNext = new TaskCompletionSource<bool>();
var tcsCommitMessage = new TaskCompletionSource<bool>();

_mockStream.SetupSequence(stream => stream.Write(It.IsAny<FromClient>()))
.Returns(Task.CompletedTask)
.Returns(Task.CompletedTask)
.Returns(() =>
{
tcsMoveNext.SetResult(true);

return Task.CompletedTask;
})
.Returns(Task.CompletedTask)
.Returns(() =>
{
tcsCommitMessage.SetResult(true);

return Task.CompletedTask;
});

_mockStream.SetupSequence(stream => stream.MoveNextAsync())
.ReturnsAsync(true)
.ReturnsAsync(true)
.Returns(new ValueTask<bool>(tcsMoveNext.Task))
.Returns(new ValueTask<bool>(tcsCommitMessage.Task))
.Returns(new ValueTask<bool>(new TaskCompletionSource<bool>().Task));

_mockStream.SetupSequence(stream => stream.Current)
.Returns(InitResponseFromServer)
.Returns(StartPartitionSessionRequest())
.Returns(ReadResponse(0, BitConverter.GetBytes(100)))
.Returns(CommitOffsetResponse());

using var reader = new ReaderBuilder<int>(_mockIDriver.Object)
{
ConsumerName = "Consumer Tester",
MemoryUsageMaxBytes = 1000,
SubscribeSettings = { new SubscribeSettings("/topic") }
}.Build();

var message = await reader.ReadAsync();
await message.CommitAsync();
Assert.Equal(100, message.Data);

_mockStream.Verify(stream => stream.Write(It.IsAny<FromClient>()), Times.Exactly(5));
_mockStream.Verify(stream => stream.MoveNextAsync(), Times.Between(4, 5, Range.Inclusive));
_mockStream.Verify(stream => stream.Current, Times.Exactly(4));

_mockStream.Verify(stream => stream.Write(It.Is<FromClient>(msg =>
msg.InitRequest != null &&
msg.InitRequest.Consumer == "Consumer Tester" &&
msg.InitRequest.TopicsReadSettings[0].Path == "/topic")));
_mockStream.Verify(stream => stream.Write(It.Is<FromClient>(msg =>
msg.ReadRequest != null &&
msg.ReadRequest.BytesSize == 1000)));
_mockStream.Verify(stream => stream.Write(It.Is<FromClient>(msg =>
msg.StartPartitionSessionResponse != null &&
msg.StartPartitionSessionResponse.PartitionSessionId == 1)));
_mockStream.Verify(stream => stream.Write(It.Is<FromClient>(msg =>
msg.ReadRequest != null)));
_mockStream.Verify(stream => stream.Write(It.Is<FromClient>(msg =>
msg.CommitOffsetRequest != null &&
msg.CommitOffsetRequest.CommitOffsets[0].PartitionSessionId == 1 &&
msg.CommitOffsetRequest.CommitOffsets[0].Offsets[0].End == 1)));
_mockStream.Verify(stream => stream.Write(It.Is<FromClient>(msg =>
msg.UpdateTokenRequest != null &&
msg.UpdateTokenRequest.Token == "Token2")));
}

private class FailDeserializer : IDeserializer<int>
{
public int Deserialize(byte[] data)
Expand Down
Loading
Loading