Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/gotrue/lib/gotrue.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ export 'src/constants.dart'
hide Constants, GenerateLinkTypeExtended, AuthChangeEventExtended;
export 'src/gotrue_admin_api.dart';
export 'src/gotrue_client.dart';
export 'src/helper.dart' show decodeJwt, validateExp;
export 'src/types/auth_exception.dart';
export 'src/types/auth_response.dart' hide ToSnakeCase;
export 'src/types/auth_state.dart';
export 'src/types/gotrue_async_storage.dart';
export 'src/types/jwt.dart';
export 'src/types/mfa.dart';
export 'src/types/types.dart';
export 'src/types/session.dart';
Expand Down
91 changes: 91 additions & 0 deletions packages/gotrue/lib/src/base64url.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import 'dart:convert';
import 'dart:typed_data';

/// Base64URL encoding and decoding utilities for JWT operations.
/// Uses dart:convert for the core base64 operations and converts to/from base64url format.
class Base64Url {
/// Decodes a base64url encoded string to bytes
///
/// [input] The base64url encoded string to decode
/// [loose] If true, allows lenient parsing that doesn't strictly validate padding
static Uint8List decode(String input, {bool loose = false}) {
// Convert base64url to base64 by replacing characters and adding padding
String base64 = _base64urlToBase64(input);

try {
return base64Decode(base64);
} catch (e) {
if (loose) {
// Try to decode with minimal padding adjustments
return _decodeLoose(input);
}
rethrow;
}
}

/// Encodes bytes to a base64url encoded string
///
/// [data] The bytes to encode
/// [pad] If true, adds padding characters to the output
static String encode(List<int> data, {bool pad = false}) {
// Use dart:convert base64 encoding
String base64 = base64Encode(data);

// Convert base64 to base64url
String base64url = _base64ToBase64url(base64);

// Remove padding if not requested
if (!pad) {
base64url = base64url.replaceAll('=', '');
}

return base64url;
}

/// Decodes a base64url string to a UTF-8 string
static String decodeToString(String input, {bool loose = false}) {
final bytes = decode(input, loose: loose);
return utf8.decode(bytes);
}

/// Encodes a UTF-8 string to base64url
static String encodeFromString(String input, {bool pad = false}) {
final bytes = utf8.encode(input);
return encode(bytes, pad: pad);
}

/// Converts base64url to base64 format
static String _base64urlToBase64(String base64url) {
// Replace base64url characters with base64 characters
String base64 = base64url.replaceAll('-', '+').replaceAll('_', '/');

// Add padding if needed
int paddingLength = (4 - (base64.length % 4)) % 4;
return base64 + '=' * paddingLength;
}

/// Converts base64 to base64url format
static String _base64ToBase64url(String base64) {
// Replace characters (keep padding as-is)
return base64.replaceAll('+', '-').replaceAll('/', '_');
}

/// Loose decoding for malformed base64url strings
static Uint8List _decodeLoose(String input) {
// Try to fix common issues and decode
String fixed = input;

// Add minimal padding if needed
if (fixed.length % 4 != 0) {
fixed += '=' * (4 - (fixed.length % 4));
}

String base64 = _base64urlToBase64(fixed);

try {
return base64Decode(base64);
} catch (e) {
throw FormatException('Invalid base64url string: $input');
}
}
}
3 changes: 3 additions & 0 deletions packages/gotrue/lib/src/constants.dart
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class Constants {

/// The name of the header that contains API version.
static const apiVersionHeaderName = 'x-supabase-api-version';

/// The TTL for the JWKS cache.
static const jwksTtl = Duration(minutes: 10);
}

class ApiVersions {
Expand Down
124 changes: 124 additions & 0 deletions packages/gotrue/lib/src/gotrue_client.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import 'dart:async';
import 'dart:convert';
import 'dart:math';
import 'dart:typed_data';

import 'package:collection/collection.dart';
import 'package:gotrue/gotrue.dart';
Expand All @@ -13,6 +14,7 @@ import 'package:http/http.dart';
import 'package:jwt_decode/jwt_decode.dart';
import 'package:logging/logging.dart';
import 'package:meta/meta.dart';
import 'package:pointycastle/export.dart';
import 'package:retry/retry.dart';
import 'package:rxdart/subjects.dart';

Expand Down Expand Up @@ -58,6 +60,9 @@ class GoTrueClient {
/// Completer to combine multiple simultaneous token refresh requests.
Completer<AuthResponse>? _refreshTokenCompleter;

JWKSet? _jwks;
DateTime? _jwksCachedAt;

final _onAuthStateChangeController = BehaviorSubject<AuthState>();
final _onAuthStateChangeControllerSync =
BehaviorSubject<AuthState>(sync: true);
Expand Down Expand Up @@ -1336,4 +1341,123 @@ class GoTrueClient {
);
return exception;
}

Future<JWK?> _fetchJwk(String kid, JWKSet suppliedJwks) async {
// try fetching from the supplied jwks
final jwk = suppliedJwks.keys.firstWhereOrNull((jwk) => jwk.kid == kid);
if (jwk != null) {
return jwk;
}

final now = DateTime.now();

// try fetching from cache
final cachedJwk = _jwks?.keys.firstWhereOrNull((jwk) => jwk.kid == kid);

// jwks exists and it isn't stale
if (cachedJwk != null &&
_jwksCachedAt != null &&
_jwksCachedAt!.add(Constants.jwksTtl).isAfter(now)) {
return cachedJwk;
}

// jwk isn't cached in memory so we need to fetch it from the well-known endpoint
final jwksResponse = await _fetch.request(
'$_url/.well-known/jwks.json',
RequestMethodType.get,
options: GotrueRequestOptions(headers: _headers),
);

final jwks = JWKSet.fromJson(jwksResponse as Map<String, dynamic>);

if (jwks.keys.isEmpty) {
return null;
}

_jwks = jwks;
_jwksCachedAt = now;

// find the signing key
return jwks.keys.firstWhereOrNull((jwk) => jwk.kid == kid);
}

/// Extracts the JWT claims present in the access token by first verifying the
/// JWT against the server's JSON Web Key Set endpoint
/// `/.well-known/jwks.json` which is often cached, resulting in significantly
/// faster responses. Prefer this method over [getUser] which always
/// sends a request to the Auth server for each JWT.
///
/// If the project is not using an asymmetric JWT signing key (like ECC or
/// RSA) it always sends a request to the Auth server (similar to [getUser]) to verify the JWT.
/// [jwt] An optional specific JWT you wish to verify, not the one you
/// can obtain from [currentSession].
/// [options] Various additional options that allow you to customize the
/// behavior of this method.
///
/// Returns a [GetClaimsResponse] containing the JWT claims, or throws an [AuthException] on error.
Future<GetClaimsResponse> getClaims([
String? jwt,
GetClaimsOptions? options,
]) async {
String token = jwt ?? '';

if (token.isEmpty) {
final session = currentSession;
if (session == null) {
throw AuthSessionMissingException('No session found');
}
token = session.accessToken;
}

// Decode the JWT to get the payload
final decoded = decodeJwt(token);

// Validate expiration unless allowExpired is true
if (!(options?.allowExpired ?? false)) {
validateExp(decoded.payload.exp);
}

// For symmetric algorithms (HS256, HS384, HS512) or missing kid, use server verification
if (decoded.header.kid == null || decoded.header.alg.startsWith('HS')) {
await getUser(token);
return GetClaimsResponse(
claims: decoded.payload,
header: decoded.header,
signature: decoded.signature);
}

final signingKey =
(decoded.header.kid == null || decoded.header.alg.startsWith('HS'))
? null
: await _fetchJwk(decoded.header.kid!, _jwks!);

// If symmetric algorithm, fallback to getUser()
if (signingKey == null) {
await getUser(token);
return GetClaimsResponse(
claims: decoded.payload,
header: decoded.header,
signature: decoded.signature);
}

final publicKey = RSAPublicKey(signingKey['n'], signingKey['e']);
final signer = RSASigner(SHA256Digest(), '0609608648016503040201'); // PKCS1
signer.init(false, PublicKeyParameter<RSAPublicKey>(publicKey));

final signature = RSASignature(Uint8List.fromList(decoded.signature));
final isValidSignature = signer.verifySignature(
Uint8List.fromList(
utf8.encode('${decoded.raw.header}.${decoded.raw.payload}')),
signature,
);

if (!isValidSignature) {
throw AuthException('Invalid JWT signature');
}

return GetClaimsResponse(
claims: decoded.payload,
header: decoded.header,
signature: decoded.signature);
}
}
60 changes: 60 additions & 0 deletions packages/gotrue/lib/src/helper.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ import 'dart:convert';
import 'dart:math';

import 'package:crypto/crypto.dart';
import 'package:gotrue/src/base64url.dart';
import 'package:gotrue/src/types/auth_exception.dart';
import 'package:gotrue/src/types/jwt.dart';

/// Converts base 10 int into String representation of base 16 int and takes the last two digets.
String dec2hex(int dec) {
Expand Down Expand Up @@ -30,3 +33,60 @@ void validateUuid(String id) {
throw ArgumentError('Invalid id: $id, must be a valid UUID');
}
}

/// Decodes a JWT token without performing validation
///
/// Returns a [DecodedJwt] containing the header, payload, signature, and raw parts.
/// Throws [AuthInvalidJwtException] if the JWT structure is invalid.
DecodedJwt decodeJwt(String token) {
final parts = token.split('.');
if (parts.length != 3) {
throw AuthInvalidJwtException('Invalid JWT structure');
}

final rawHeader = parts[0];
final rawPayload = parts[1];
final rawSignature = parts[2];

try {
// Decode header
final headerJson = Base64Url.decodeToString(rawHeader, loose: true);
final header = JwtHeader.fromJson(json.decode(headerJson));

// Decode payload
final payloadJson = Base64Url.decodeToString(rawPayload, loose: true);
final payload = JwtPayload.fromJson(json.decode(payloadJson));

// Decode signature
final signature = Base64Url.decode(rawSignature, loose: true);

return DecodedJwt(
header: header,
payload: payload,
signature: signature,
raw: JwtRawParts(
header: rawHeader,
payload: rawPayload,
signature: rawSignature,
),
);
} catch (e) {
if (e is AuthInvalidJwtException) {
rethrow;
}
throw AuthInvalidJwtException('Failed to decode JWT: $e');
}
}

/// Validates the expiration time of a JWT
///
/// Throws [AuthException] if the exp claim is missing or the JWT has expired.
void validateExp(int? exp) {
if (exp == null) {
throw AuthException('Missing exp claim');
}
final timeNow = DateTime.now().millisecondsSinceEpoch / 1000;
if (exp <= timeNow) {
throw AuthException('JWT has expired');
}
}
12 changes: 12 additions & 0 deletions packages/gotrue/lib/src/types/auth_exception.dart
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,15 @@ class AuthWeakPasswordException extends AuthException {
String toString() =>
'AuthWeakPasswordException(message: $message, statusCode: $statusCode, reasons: $reasons)';
}

class AuthInvalidJwtException extends AuthException {
AuthInvalidJwtException(super.message)
: super(
statusCode: '400',
code: 'invalid_jwt',
);

@override
String toString() =>
'AuthInvalidJwtException(message: $message, statusCode: $statusCode, code: $code)';
}
Loading