1+ using System . IdentityModel . Tokens . Jwt ;
2+ using Microsoft . Extensions . Logging ;
3+ using Microsoft . Extensions . Logging . Abstractions ;
4+ using Ydb . Sdk . Services . Auth ;
5+
6+ namespace Ydb . Sdk . Auth ;
7+
8+ public class StaticCredentialsProvider : ICredentialsProvider , IUseDriverConfig
9+ {
10+ private readonly ILogger _logger ;
11+
12+ private readonly string _user ;
13+ private readonly string ? _password ;
14+
15+ private Driver ? _driver ;
16+
17+ public int MaxRetries = 5 ;
18+
19+ private readonly object _lock = new ( ) ;
20+
21+ private volatile TokenData ? _token ;
22+ private volatile Task ? _refreshTask ;
23+
24+ public float RefreshRatio = .1f ;
25+
26+ /// <summary>
27+ ///
28+ /// </summary>
29+ /// <param name="user">User of the database</param>
30+ /// <param name="password">Password of the user. If user has no password use null </param>
31+ /// <param name="loggerFactory"></param>
32+ public StaticCredentialsProvider ( string user , string ? password , ILoggerFactory ? loggerFactory = null )
33+ {
34+ _user = user ;
35+ _password = password ;
36+ loggerFactory ??= NullLoggerFactory . Instance ;
37+ _logger = loggerFactory . CreateLogger < StaticCredentialsProvider > ( ) ;
38+ }
39+
40+ private async Task Initialize ( )
41+ {
42+ _token = await ReceiveToken ( ) ;
43+ }
44+
45+ public string GetAuthInfo ( )
46+ {
47+ var token = _token ;
48+
49+ if ( token is null )
50+ {
51+ lock ( _lock )
52+ {
53+ if ( _token is not null ) return _token . Token ;
54+ _logger . LogWarning (
55+ "Blocking for initial token acquirement, please use explicit Initialize async method." ) ;
56+
57+ Initialize ( ) . Wait ( ) ;
58+
59+ return _token ! . Token ;
60+ }
61+ }
62+
63+ if ( token . IsExpired ( ) )
64+ {
65+ lock ( _lock )
66+ {
67+ if ( ! _token ! . IsExpired ( ) ) return _token . Token ;
68+ _logger . LogWarning ( "Blocking on expired token." ) ;
69+
70+ _token = ReceiveToken ( ) . Result ;
71+
72+ return _token . Token ;
73+ }
74+ }
75+
76+ if ( ! token . IsRefreshNeeded ( ) || _refreshTask is not null ) return _token ! . Token ;
77+ lock ( _lock )
78+ {
79+ if ( ! _token ! . IsRefreshNeeded ( ) || _refreshTask is not null ) return _token ! . Token ;
80+ _logger . LogInformation ( "Refreshing token." ) ;
81+
82+ _refreshTask = Task . Run ( RefreshToken ) ;
83+ }
84+
85+ return _token ! . Token ;
86+ }
87+
88+ private async Task RefreshToken ( )
89+ {
90+ var token = await ReceiveToken ( ) ;
91+
92+ lock ( _lock )
93+ {
94+ _token = token ;
95+ _refreshTask = null ;
96+ }
97+ }
98+
99+ private async Task < TokenData > ReceiveToken ( )
100+ {
101+ var retryAttempt = 0 ;
102+ while ( true )
103+ {
104+ try
105+ {
106+ _logger . LogTrace ( $ "Attempting to receive token, attempt: { retryAttempt } ") ;
107+
108+ var token = await FetchToken ( ) ;
109+
110+ _logger . LogInformation ( $ "Received token, expires at: { token . ExpiresAt } ") ;
111+
112+ return token ;
113+ }
114+ catch ( InvalidCredentialsException e )
115+ {
116+ _logger . LogWarning ( $ "Invalid credentials, { e } ") ;
117+ throw ;
118+ }
119+ catch ( Exception e )
120+ {
121+ _logger . LogDebug ( $ "Failed to fetch token, { e } ") ;
122+
123+ if ( retryAttempt >= MaxRetries )
124+ {
125+ _logger . LogWarning ( $ "Can't fetch token, { e } ") ;
126+ throw ;
127+ }
128+
129+ await Task . Delay ( TimeSpan . FromSeconds ( Math . Pow ( 2 , retryAttempt ) ) ) ;
130+ _logger . LogInformation ( $ "Failed to fetch token, attempt { retryAttempt } ") ;
131+ ++ retryAttempt ;
132+ }
133+ }
134+ }
135+
136+ private async Task < TokenData > FetchToken ( )
137+ {
138+ if ( _driver is null )
139+ {
140+ _logger . LogError ( "Driver in for static auth not provided" ) ;
141+ throw new NullReferenceException ( ) ;
142+ }
143+
144+ var client = new AuthClient ( _driver ) ;
145+ var loginResponse = await client . Login ( _user , _password ) ;
146+ if ( loginResponse . Status . StatusCode == StatusCode . Unauthorized )
147+ {
148+ throw new InvalidCredentialsException ( Issue . IssuesToString ( loginResponse . Status . Issues ) ) ;
149+ }
150+
151+ loginResponse . Status . EnsureSuccess ( ) ;
152+ var token = loginResponse . Result . Token ;
153+ var jwt = new JwtSecurityToken ( token ) ;
154+ return new TokenData ( token , jwt . ValidTo , RefreshRatio ) ;
155+ }
156+
157+ public async Task ProvideConfig ( DriverConfig driverConfig )
158+ {
159+ _driver = await Driver . CreateInitialized (
160+ new DriverConfig (
161+ driverConfig . Endpoint ,
162+ driverConfig . Database ,
163+ new AnonymousProvider ( ) ,
164+ driverConfig . DefaultTransportTimeout ,
165+ driverConfig . DefaultStreamingTransportTimeout ,
166+ driverConfig . CustomServerCertificate ) ) ;
167+
168+ await Initialize ( ) ;
169+ }
170+
171+ private class TokenData
172+ {
173+ public TokenData ( string token , DateTime expiresAt , float refreshInterval )
174+ {
175+ var now = DateTime . UtcNow ;
176+
177+ Token = token ;
178+ ExpiresAt = expiresAt ;
179+
180+ if ( expiresAt <= now )
181+ {
182+ RefreshAt = expiresAt ;
183+ }
184+ else
185+ {
186+ RefreshAt = now + ( expiresAt - now ) * refreshInterval ;
187+
188+ if ( RefreshAt < now )
189+ {
190+ RefreshAt = expiresAt ;
191+ }
192+ }
193+ }
194+
195+ public string Token { get ; }
196+ public DateTime ExpiresAt { get ; }
197+
198+ private DateTime RefreshAt { get ; }
199+
200+ public bool IsExpired ( )
201+ {
202+ return DateTime . UtcNow >= ExpiresAt ;
203+ }
204+
205+ public bool IsRefreshNeeded ( )
206+ {
207+ return DateTime . UtcNow >= RefreshAt ;
208+ }
209+ }
210+ }
0 commit comments