Skip to content

Commit 1bb034f

Browse files
authored
fix(realtime_client): Prevent sending expired tokens (#1095)
* fix: prevent sending expired tokens * widen the constraint for crypto dev dependencies on realtime * fix: properly handle exception on supabase-client for realtime set auth * fix: handle realtime token exception on SupabaseClient * await all setAuth calls * pass custom access token as params * properly parse JWT within realtime client
1 parent c971786 commit 1bb034f

File tree

5 files changed

+155
-24
lines changed

5 files changed

+155
-24
lines changed

packages/realtime_client/lib/src/realtime_channel.dart

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,11 @@ class RealtimeChannel {
150150

151151
joinPush.receive(
152152
'ok',
153-
(response) {
153+
(response) async {
154154
final serverPostgresFilters = response['postgres_changes'];
155-
if (socket.accessToken != null) socket.setAuth(socket.accessToken);
155+
if (socket.accessToken != null) {
156+
await socket.setAuth(socket.accessToken);
157+
}
156158

157159
if (serverPostgresFilters == null) {
158160
if (callback != null) {

packages/realtime_client/lib/src/realtime_client.dart

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class RealtimeCloseEvent {
5454
}
5555

5656
class RealtimeClient {
57+
// This is named `accessTokenValue` in supabase-js
5758
String? accessToken;
5859
List<RealtimeChannel> channels = [];
5960
final String endPoint;
@@ -89,6 +90,8 @@ class RealtimeClient {
8990
};
9091
int longpollerTimeout = 20000;
9192
SocketStates? connState;
93+
// This is called `accessToken` in realtime-js
94+
Future<String> Function()? customAccessToken;
9295

9396
/// Initializes the Socket
9497
///
@@ -129,6 +132,7 @@ class RealtimeClient {
129132
this.longpollerTimeout = 20000,
130133
RealtimeLogLevel? logLevel,
131134
this.httpClient,
135+
this.customAccessToken,
132136
}) : endPoint = Uri.parse('$endPoint/${Transports.websocket}')
133137
.replace(
134138
queryParameters:
@@ -403,15 +407,43 @@ class RealtimeClient {
403407
/// Sets the JWT access token used for channel subscription authorization and Realtime RLS.
404408
///
405409
/// `token` A JWT strings.
406-
void setAuth(String? token) {
407-
accessToken = token;
410+
Future<void> setAuth(String? token) async {
411+
final tokenToSend =
412+
token ?? (await customAccessToken?.call()) ?? accessToken;
413+
414+
if (tokenToSend != null) {
415+
Map<String, dynamic>? parsed;
416+
try {
417+
final decoded =
418+
base64.decode(base64.normalize(tokenToSend.split('.')[1]));
419+
parsed = json.decode(utf8.decode(decoded));
420+
} catch (e) {
421+
// ignore parsing errors
422+
}
423+
if (parsed != null && parsed['exp'] != null) {
424+
final now = (DateTime.now().millisecondsSinceEpoch / 1000).floor();
425+
final valid = now - parsed['exp'] < 0;
426+
if (!valid) {
427+
log(
428+
'auth',
429+
'InvalidJWTToken: Invalid value for JWT claim "exp" with value ${parsed['exp']}',
430+
null,
431+
Level.FINE,
432+
);
433+
throw FormatException(
434+
'InvalidJWTToken: Invalid value for JWT claim "exp" with value ${parsed['exp']}');
435+
}
436+
}
437+
}
438+
439+
accessToken = tokenToSend;
408440

409441
for (final channel in channels) {
410-
if (token != null) {
411-
channel.updateJoinPayload({'access_token': token});
442+
if (tokenToSend != null) {
443+
channel.updateJoinPayload({'access_token': tokenToSend});
412444
}
413445
if (channel.joinedOnce && channel.isJoined) {
414-
channel.push(ChannelEvents.accessToken, {'access_token': token});
446+
channel.push(ChannelEvents.accessToken, {'access_token': tokenToSend});
415447
}
416448
}
417449
}
@@ -436,7 +468,7 @@ class RealtimeClient {
436468
if (heartbeatTimer != null) heartbeatTimer!.cancel();
437469
heartbeatTimer = Timer.periodic(
438470
Duration(milliseconds: heartbeatIntervalMs),
439-
(Timer t) => sendHeartbeat(),
471+
(Timer t) async => await sendHeartbeat(),
440472
);
441473
for (final callback in stateChangeCallbacks['open']!) {
442474
callback();
@@ -502,7 +534,7 @@ class RealtimeClient {
502534
}
503535

504536
@internal
505-
void sendHeartbeat() {
537+
Future<void> sendHeartbeat() async {
506538
if (!isConnected) {
507539
return;
508540
}
@@ -524,6 +556,6 @@ class RealtimeClient {
524556
payload: {},
525557
ref: pendingHeartbeatRef!,
526558
));
527-
setAuth(accessToken);
559+
await setAuth(accessToken);
528560
}
529561
}

packages/realtime_client/pubspec.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ dev_dependencies:
1919
lints: ^3.0.0
2020
mocktail: ^1.0.0
2121
test: ^1.16.5
22+
crypto: ^3.0.0

packages/realtime_client/test/socket_test.dart

Lines changed: 94 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import 'dart:convert';
22
import 'dart:io';
33

4+
import 'package:crypto/crypto.dart';
45
import 'package:mocktail/mocktail.dart';
56
import 'package:realtime_client/realtime_client.dart';
67
import 'package:realtime_client/src/constants.dart';
@@ -16,6 +17,31 @@ typedef WebSocketChannelClosure = WebSocketChannel Function(
1617
Map<String, String> headers,
1718
);
1819

20+
/// Generate a JWT token for testing purposes
21+
///
22+
/// [exp] in seconds since Epoch
23+
String generateJwt([int? exp]) {
24+
final header = {'alg': 'HS256', 'typ': 'JWT'};
25+
26+
final now = DateTime.now();
27+
final expiry = exp ??
28+
(now.add(Duration(hours: 1)).millisecondsSinceEpoch / 1000).floor();
29+
30+
final payload = {'exp': expiry};
31+
32+
final key = 'your-256-bit-secret';
33+
34+
final encodedHeader = base64Url.encode(utf8.encode(json.encode(header)));
35+
final encodedPayload = base64Url.encode(utf8.encode(json.encode(payload)));
36+
37+
final signatureInput = '$encodedHeader.$encodedPayload';
38+
final hmac = Hmac(sha256, utf8.encode(key));
39+
final digest = hmac.convert(utf8.encode(signatureInput));
40+
final signature = base64Url.encode(digest.bytes);
41+
42+
return '$encodedHeader.$encodedPayload.$signature';
43+
}
44+
1945
void main() {
2046
const int int64MaxValue = 9223372036854775807;
2147

@@ -174,7 +200,7 @@ void main() {
174200
await Future.delayed(const Duration(milliseconds: 200));
175201
expect(opens, 1);
176202

177-
socket.sendHeartbeat();
203+
await socket.sendHeartbeat();
178204
// need to wait for event to trigger
179205
await Future.delayed(const Duration(seconds: 1));
180206
expect(lastMsg['event'], 'heartbeat');
@@ -427,12 +453,13 @@ void main() {
427453
});
428454

429455
group('setAuth', () {
430-
final updateJoinPayload = {'access_token': 'token123'};
431-
final pushPayload = {'access_token': 'token123'};
456+
final token = generateJwt();
457+
final updateJoinPayload = {'access_token': token};
458+
final pushPayload = {'access_token': token};
432459

433460
test(
434461
"sets access token, updates channels' join payload, and pushes token to channels",
435-
() {
462+
() async {
436463
final mockedChannel1 = MockChannel();
437464
when(() => mockedChannel1.joinedOnce).thenReturn(true);
438465
when(() => mockedChannel1.isJoined).thenReturn(true);
@@ -457,7 +484,9 @@ void main() {
457484
final channel1 = mockedSocket.channel(tTopic1);
458485
final channel2 = mockedSocket.channel(tTopic2);
459486

460-
mockedSocket.setAuth('token123');
487+
await mockedSocket.setAuth(token);
488+
489+
expect(mockedSocket.accessToken, token);
461490

462491
verify(() => channel1.updateJoinPayload(updateJoinPayload)).called(1);
463492
verify(() => channel2.updateJoinPayload(updateJoinPayload)).called(1);
@@ -466,6 +495,62 @@ void main() {
466495
verify(() => channel2.push(ChannelEvents.accessToken, pushPayload))
467496
.called(1);
468497
});
498+
499+
test(
500+
"sets access token, updates channels' join payload, and pushes token to channels if is not a jwt",
501+
() async {
502+
final mockedChannel1 = MockChannel();
503+
final mockedChannel2 = MockChannel();
504+
final mockedChannel3 = MockChannel();
505+
506+
when(() => mockedChannel1.joinedOnce).thenReturn(true);
507+
when(() => mockedChannel1.isJoined).thenReturn(true);
508+
when(() => mockedChannel1.push(ChannelEvents.accessToken, any()))
509+
.thenReturn(MockPush());
510+
511+
when(() => mockedChannel2.joinedOnce).thenReturn(false);
512+
when(() => mockedChannel2.isJoined).thenReturn(false);
513+
when(() => mockedChannel2.push(ChannelEvents.accessToken, any()))
514+
.thenReturn(MockPush());
515+
516+
when(() => mockedChannel3.joinedOnce).thenReturn(true);
517+
when(() => mockedChannel3.isJoined).thenReturn(true);
518+
when(() => mockedChannel3.push(ChannelEvents.accessToken, any()))
519+
.thenReturn(MockPush());
520+
521+
const tTopic1 = 'test-topic1';
522+
const tTopic2 = 'test-topic2';
523+
const tTopic3 = 'test-topic3';
524+
525+
final mockedSocket = SocketWithMockedChannel(socketEndpoint);
526+
mockedSocket.mockedChannelLooker.addAll(<String, RealtimeChannel>{
527+
tTopic1: mockedChannel1,
528+
tTopic2: mockedChannel2,
529+
tTopic3: mockedChannel3,
530+
});
531+
532+
final channel1 = mockedSocket.channel(tTopic1);
533+
final channel2 = mockedSocket.channel(tTopic2);
534+
final channel3 = mockedSocket.channel(tTopic3);
535+
536+
const token = 'sb-key';
537+
final pushPayload = {'access_token': token};
538+
final updateJoinPayload = {'access_token': token};
539+
540+
await mockedSocket.setAuth(token);
541+
542+
expect(mockedSocket.accessToken, token);
543+
544+
verify(() => channel1.updateJoinPayload(updateJoinPayload)).called(1);
545+
verify(() => channel2.updateJoinPayload(updateJoinPayload)).called(1);
546+
verify(() => channel3.updateJoinPayload(updateJoinPayload)).called(1);
547+
548+
verify(() => channel1.push(ChannelEvents.accessToken, pushPayload))
549+
.called(1);
550+
verifyNever(() => channel2.push(ChannelEvents.accessToken, pushPayload));
551+
verify(() => channel3.push(ChannelEvents.accessToken, pushPayload))
552+
.called(1);
553+
});
469554
});
470555

471556
group('sendHeartbeat', () {
@@ -496,18 +581,18 @@ void main() {
496581

497582
//! Unimplemented Test: closes socket when heartbeat is not ack'd within heartbeat window
498583

499-
test('pushes heartbeat data when connected', () {
584+
test('pushes heartbeat data when connected', () async {
500585
mockedSocket.connState = SocketStates.open;
501586

502-
mockedSocket.sendHeartbeat();
587+
await mockedSocket.sendHeartbeat();
503588

504589
verify(() => mockedSink.add(captureAny(that: equals(data)))).called(1);
505590
});
506591

507-
test('no ops when not connected', () {
592+
test('no ops when not connected', () async {
508593
mockedSocket.connState = SocketStates.connecting;
509594

510-
mockedSocket.sendHeartbeat();
595+
await mockedSocket.sendHeartbeat();
511596
verifyNever(() => mockedSink.add(any()));
512597
});
513598
});

packages/supabase/lib/src/supabase_client.dart

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ class SupabaseClient {
332332
logLevel: options.logLevel,
333333
httpClient: _authHttpClient,
334334
timeout: options.timeout ?? RealtimeConstants.defaultTimeout,
335+
customAccessToken: accessToken,
335336
);
336337
}
337338

@@ -349,22 +350,32 @@ class SupabaseClient {
349350
void _listenForAuthEvents() {
350351
// ignore: invalid_use_of_internal_member
351352
_authStateSubscription = auth.onAuthStateChangeSync.listen(
352-
(data) {
353-
_handleTokenChanged(data.event, data.session?.accessToken);
353+
(data) async {
354+
await _handleTokenChanged(data.event, data.session?.accessToken);
354355
},
355356
onError: (error, stack) {},
356357
);
357358
}
358359

359-
void _handleTokenChanged(AuthChangeEvent event, String? token) {
360+
Future<void> _handleTokenChanged(AuthChangeEvent event, String? token) async {
360361
if (event == AuthChangeEvent.initialSession ||
361362
event == AuthChangeEvent.tokenRefreshed ||
362363
event == AuthChangeEvent.signedIn) {
363-
realtime.setAuth(token);
364+
try {
365+
await realtime.setAuth(token);
366+
} on FormatException catch (e) {
367+
if (e.message.contains('InvalidJWTToken')) {
368+
// The exception is thrown by RealtimeClient when the token is
369+
// expired for example on app launch after the app has been closed
370+
// for a while.
371+
} else {
372+
rethrow;
373+
}
374+
}
364375
} else if (event == AuthChangeEvent.signedOut) {
365376
// Token is removed
366377

367-
realtime.setAuth(_supabaseKey);
378+
await realtime.setAuth(_supabaseKey);
368379
}
369380
}
370381
}

0 commit comments

Comments
 (0)