Skip to content

Commit 1547f5f

Browse files
fix race condition on volatile _fetchTokenTask
1 parent e77e872 commit 1547f5f

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

src/Ydb.Sdk/src/Auth/CachedCredentialsProvider.cs

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ public async ValueTask<ITokenState> Validate(DateTime now)
111111

112112
private class SyncState : ITokenState
113113
{
114-
private readonly CachedCredentialsProvider _cachedCredentialsProvider;
114+
private readonly TaskCompletionSource<TokenResponse> _fetchTokenResponseTcs = new();
115115

116-
private volatile Task<TokenResponse> _fetchTokenTask = null!;
116+
private readonly CachedCredentialsProvider _cachedCredentialsProvider;
117117

118118
public SyncState(CachedCredentialsProvider cachedCredentialsProvider)
119119
{
@@ -127,7 +127,7 @@ public async ValueTask<ITokenState> Validate(DateTime now)
127127
{
128128
try
129129
{
130-
var tokenResponse = await _fetchTokenTask;
130+
var tokenResponse = await _fetchTokenResponseTcs.Task;
131131

132132
_cachedCredentialsProvider.Logger.LogDebug(
133133
"Successfully fetched token. ExpiredAt: {ExpiredAt}, RefreshAt: {RefreshAt}",
@@ -145,14 +145,26 @@ public async ValueTask<ITokenState> Validate(DateTime now)
145145
}
146146
}
147147

148-
public void Init() => _fetchTokenTask = _cachedCredentialsProvider.FetchToken();
148+
public async void Init()
149+
{
150+
try
151+
{
152+
var tokenResponse = await _cachedCredentialsProvider.FetchToken();
153+
154+
_fetchTokenResponseTcs.SetResult(tokenResponse);
155+
}
156+
catch (Exception e)
157+
{
158+
_fetchTokenResponseTcs.SetException(e);
159+
}
160+
}
149161
}
150162

151163
private class BackgroundState : ITokenState
152164
{
153-
private readonly CachedCredentialsProvider _cachedCredentialsProvider;
165+
private readonly TaskCompletionSource<TokenResponse> _fetchTokenResponseTcs = new();
154166

155-
private volatile Task<TokenResponse> _fetchTokenTask = null!;
167+
private readonly CachedCredentialsProvider _cachedCredentialsProvider;
156168

157169
public BackgroundState(TokenResponse tokenResponse,
158170
CachedCredentialsProvider cachedCredentialsProvider)
@@ -165,11 +177,13 @@ public BackgroundState(TokenResponse tokenResponse,
165177

166178
public async ValueTask<ITokenState> Validate(DateTime now)
167179
{
168-
if (_fetchTokenTask.IsCanceled || _fetchTokenTask.IsFaulted)
180+
var fetchTokenTask = _fetchTokenResponseTcs.Task;
181+
182+
if (fetchTokenTask.IsCanceled || fetchTokenTask.IsFaulted)
169183
{
170184
_cachedCredentialsProvider.Logger.LogWarning(
171185
"Fetching token task failed. Status: {Status}, Retrying login...",
172-
_fetchTokenTask.IsCanceled ? "Canceled" : "Faulted"
186+
fetchTokenTask.IsCanceled ? "Canceled" : "Faulted"
173187
);
174188

175189
return now >= TokenResponse.ExpiredAt
@@ -180,10 +194,10 @@ public async ValueTask<ITokenState> Validate(DateTime now)
180194
.UpdateState(this, new BackgroundState(TokenResponse, _cachedCredentialsProvider));
181195
}
182196

183-
if (_fetchTokenTask.IsCompleted)
197+
if (fetchTokenTask.IsCompleted)
184198
{
185199
return _cachedCredentialsProvider
186-
.UpdateState(this, new ActiveState(await _fetchTokenTask, _cachedCredentialsProvider));
200+
.UpdateState(this, new ActiveState(await fetchTokenTask, _cachedCredentialsProvider));
187201
}
188202

189203
if (now < TokenResponse.ExpiredAt)
@@ -193,7 +207,7 @@ public async ValueTask<ITokenState> Validate(DateTime now)
193207

194208
try
195209
{
196-
var tokenResponse = await _fetchTokenTask;
210+
var tokenResponse = await fetchTokenTask;
197211

198212
_cachedCredentialsProvider.Logger.LogDebug(
199213
"Successfully fetched token. ExpiredAt: {ExpiredAt}, RefreshAt: {RefreshAt}",
@@ -207,12 +221,23 @@ public async ValueTask<ITokenState> Validate(DateTime now)
207221
{
208222
_cachedCredentialsProvider.Logger.LogCritical(e, "Error on authentication token update");
209223

210-
return _cachedCredentialsProvider.UpdateState(this,
211-
new ErrorState(e, _cachedCredentialsProvider));
224+
return _cachedCredentialsProvider.UpdateState(this, new ErrorState(e, _cachedCredentialsProvider));
212225
}
213226
}
214227

215-
public void Init() => _fetchTokenTask = _cachedCredentialsProvider.FetchToken();
228+
public async void Init()
229+
{
230+
try
231+
{
232+
var tokenResponse = await _cachedCredentialsProvider.FetchToken();
233+
234+
_fetchTokenResponseTcs.SetResult(tokenResponse);
235+
}
236+
catch (Exception e)
237+
{
238+
_fetchTokenResponseTcs.SetException(e);
239+
}
240+
}
216241
}
217242

218243
private class ErrorState : ITokenState

0 commit comments

Comments
 (0)