diff --git a/Sources/Auth/AuthClient.swift b/Sources/Auth/AuthClient.swift index 5a36766f1..0d605063b 100644 --- a/Sources/Auth/AuthClient.swift +++ b/Sources/Auth/AuthClient.swift @@ -30,6 +30,34 @@ struct AuthClientLoggerDecorator: SupabaseLogger { } } +/// JWKS cache TTL (Time To Live) - 10 minutes +private let JWKS_TTL: TimeInterval = 10 * 60 + +/// Cached JWKS value with timestamp +private struct CachedJWKS { + let jwks: JWKS + let cachedAt: Date +} + +/// Global JWKS cache shared across all clients with the same storage key. +/// This is especially useful for shared-memory execution environments such as +/// AWS Lambda or serverless functions. Regardless of how many clients are created, +/// if they share the same storage key they will use the same JWKS cache, +/// significantly speeding up getClaims() with asymmetric JWTs. +private actor GlobalJWKSCache { + private var cache: [String: CachedJWKS] = [:] + + func get(for key: String) -> CachedJWKS? { + cache[key] + } + + func set(_ value: CachedJWKS, for key: String) { + cache[key] = value + } +} + +private let globalJWKSCache = GlobalJWKSCache() + public actor AuthClient { static var globalClientID = 0 nonisolated let clientID: AuthClientID @@ -1448,6 +1476,150 @@ public actor AuthClient { return url } + + /// Fetches a JWK from the JWKS endpoint with caching + /// Returns nil if the key is not found, allowing graceful fallback to server-side verification + private func fetchJWK(kid: String, jwks: JWKS? = nil) async throws -> JWK? { + // Try fetching from the supplied jwks + if let jwk = jwks?.keys.first(where: { $0.kid == kid }) { + return jwk + } + + let now = date() + let storageKey = configuration.storageKey ?? defaultStorageKey + + // Try fetching from global cache + if let cached = await globalJWKSCache.get(for: storageKey), + let jwk = cached.jwks.keys.first(where: { $0.kid == kid }) + { + // Check if cache is still valid (not stale) + if cached.cachedAt.addingTimeInterval(JWKS_TTL) > now { + return jwk + } + } + + // Fetch from well-known endpoint + let response = try await api.execute( + HTTPRequest( + url: configuration.url.appendingPathComponent(".well-known/jwks.json"), + method: .get + ) + ) + + let fetchedJWKS = try response.decoded(as: JWKS.self, decoder: configuration.decoder) + + // Return nil if JWKS is empty (will fallback to getUser) + guard !fetchedJWKS.keys.isEmpty else { + return nil + } + + // Cache the JWKS globally + await globalJWKSCache.set( + CachedJWKS(jwks: fetchedJWKS, cachedAt: now), + for: storageKey + ) + + // Find the signing key - return nil if not found (will fallback to getUser) + // This handles key rotation scenarios where the JWT is signed with a key not yet in the cache + return fetchedJWKS.keys.first(where: { $0.kid == kid }) + } + + /// Extracts the JWT claims present in the access token by first verifying the + /// JWT against the server's JSON Web Key Set endpoint `/.well-known/jwks.json` + /// which is often cached, resulting in significantly faster responses. Prefer + /// this method over ``user(jwt:)`` which always sends a request to the Auth + /// server for each JWT. + /// + /// If the project is not using an asymmetric JWT signing key (like ECC or RSA) + /// it always sends a request to the Auth server (similar to ``user(jwt:)``) to + /// verify the JWT. + /// + /// - Parameters: + /// - jwt: An optional specific JWT you wish to verify, not the one you can obtain from ``session``. + /// - options: Various additional options that allow you to customize the behavior of this method. + /// + /// - Returns: A `JWTClaimsResponse` containing the verified claims, header, and signature. + /// + /// - Throws: `AuthError.jwtVerificationFailed` if verification fails, or `AuthError.sessionMissing` if no session exists. + public func getClaims( + jwt: String? = nil, + options: GetClaimsOptions = GetClaimsOptions() + ) async throws -> JWTClaimsResponse { + let token: String + if let jwt { + token = jwt + } else { + guard let session = try? await session else { + throw AuthError.sessionMissing + } + token = session.accessToken + } + + guard let decodedJWT = JWT.decode(token) else { + throw AuthError.jwtVerificationFailed(message: "Invalid JWT structure") + } + + // Validate expiration unless allowExpired is true + if !options.allowExpired { + if let exp = decodedJWT.payload["exp"] as? TimeInterval { + let now = date().timeIntervalSince1970 + if exp <= now { + throw AuthError.jwtVerificationFailed(message: "JWT has expired") + } + } + } + + let alg = decodedJWT.header["alg"] as? String + let kid = decodedJWT.header["kid"] as? String + + // Try to fetch the signing key for asymmetric JWTs + // Returns nil if: no alg, symmetric algorithm (HS256/HS512), no kid, or key not found in JWKS + let signingKey: JWK? + if let alg, !alg.hasPrefix("HS"), let kid { + // Only attempt to fetch JWK for asymmetric algorithms with a kid + signingKey = try await fetchJWK(kid: kid, jwks: options.jwks) + } else { + signingKey = nil + } + + // If no signing key available (symmetric algorithm, RS256, no kid, or key not found), + // fallback to server-side verification via getUser() + guard + let signingKey, + let alg = signingKey.alg, + let algorithm = JWTAlgorithm(rawValue: alg) + else { + _ = try await user(jwt: token) + // getUser succeeds, so claims can be trusted + let claims = try configuration.decoder.decode( + JWTClaims.self, + from: JSONSerialization.data(withJSONObject: decodedJWT.payload) + ) + let header = try configuration.decoder.decode( + JWTHeader.self, + from: JSONSerialization.data(withJSONObject: decodedJWT.header) + ) + return JWTClaimsResponse(claims: claims, header: header, signature: decodedJWT.signature) + } + + let isValid = algorithm.verify(jwt: decodedJWT, jwk: signingKey) + + guard isValid else { + throw AuthError.jwtVerificationFailed(message: "Invalid JWT signature") + } + + // Decode claims and header + let claims = try configuration.decoder.decode( + JWTClaims.self, + from: JSONSerialization.data(withJSONObject: decodedJWT.payload) + ) + let header = try configuration.decoder.decode( + JWTHeader.self, + from: JSONSerialization.data(withJSONObject: decodedJWT.header) + ) + + return JWTClaimsResponse(claims: claims, header: header, signature: decodedJWT.signature) + } } extension AuthClient { diff --git a/Sources/Auth/AuthError.swift b/Sources/Auth/AuthError.swift index 5349d36f7..991100bf5 100644 --- a/Sources/Auth/AuthError.swift +++ b/Sources/Auth/AuthError.swift @@ -114,6 +114,7 @@ extension ErrorCode { //#nosec G101 -- Not a secret value. public static let invalidCredentials = ErrorCode("invalid_credentials") public static let emailAddressNotAuthorized = ErrorCode("email_address_not_authorized") + public static let invalidJWT = ErrorCode("invalid_jwt") } public enum AuthError: LocalizedError, Equatable { @@ -261,13 +262,17 @@ public enum AuthError: LocalizedError, Equatable { /// Error thrown when an error happens during implicit grant flow. case implicitGrantRedirect(message: String) + /// Error thrown when JWT verification fails. + case jwtVerificationFailed(message: String) + public var message: String { switch self { case .sessionMissing: "Auth session missing." case let .weakPassword(message, _), let .api(message, _, _, _), let .pkceGrantCodeExchange(message, _, _), - let .implicitGrantRedirect(message): + let .implicitGrantRedirect(message), + let .jwtVerificationFailed(message): message // Deprecated cases case .missingExpClaim: "Missing expiration claim in the access token." @@ -283,6 +288,7 @@ public enum AuthError: LocalizedError, Equatable { case .weakPassword: .weakPassword case let .api(_, errorCode, _, _): errorCode case .pkceGrantCodeExchange, .implicitGrantRedirect: .unknown + case .jwtVerificationFailed: .invalidJWT // Deprecated cases case .missingExpClaim, .malformedJWT, .invalidRedirectScheme, .missingURL: .unknown } diff --git a/Sources/Auth/Internal/JWK+RSA.swift b/Sources/Auth/Internal/JWK+RSA.swift new file mode 100644 index 000000000..10233a678 --- /dev/null +++ b/Sources/Auth/Internal/JWK+RSA.swift @@ -0,0 +1,72 @@ +// +// JWK+RSA.swift +// Supabase +// +// Created by Guilherme Souza on 07/10/25. +// + +import Foundation + +#if canImport(Security) + + extension JWK { + var rsaPublishKey: SecKey? { + guard kty == "RSA", + alg == "RS256", + let n, + let modulus = Base64URL.decode(n), + let e, + let exponent = Base64URL.decode(e) + else { + return nil + } + + let encodedKey = encodeRSAPublishKey(modulus: [UInt8](modulus), exponent: [UInt8](exponent)) + return generateRSAPublicKey(from: encodedKey) + } + } + + extension JWK { + fileprivate func encodeRSAPublishKey(modulus: [UInt8], exponent: [UInt8]) -> Data { + var prefixedModulus: [UInt8] = [0x00] // To indicate that the number is not negative + prefixedModulus.append(contentsOf: modulus) + let encodedModulus = prefixedModulus.derEncode(as: 2) // Integer + let encodedExponent = exponent.derEncode(as: 2) // Integer + let encodedSequence = (encodedModulus + encodedExponent).derEncode(as: 48) // Sequence + return Data(encodedSequence) + } + + fileprivate func generateRSAPublicKey(from derEncodedData: Data) -> SecKey? { + let sizeInBits = derEncodedData.count * MemoryLayout.size + let attributes: [CFString: Any] = [ + kSecAttrKeyType: kSecAttrKeyTypeRSA, + kSecAttrKeyClass: kSecAttrKeyClassPublic, + kSecAttrKeySizeInBits: NSNumber(value: sizeInBits), + kSecAttrIsPermanent: false, + ] + return SecKeyCreateWithData(derEncodedData as CFData, attributes as CFDictionary, nil) + } + } + + extension [UInt8] { + fileprivate func derEncode(as dataType: UInt8) -> [UInt8] { + var encodedBytes: [UInt8] = [dataType] + var numberOfBytes = count + if numberOfBytes < 128 { + encodedBytes.append(UInt8(numberOfBytes)) + } else { + let lengthData = Data( + bytes: &numberOfBytes, + count: MemoryLayout.size(ofValue: numberOfBytes) + ) + let lengthBytes = [UInt8](lengthData).filter({ $0 != 0 }).reversed() + encodedBytes.append(UInt8(truncatingIfNeeded: lengthBytes.count) | 0b10000000) + encodedBytes.append(contentsOf: lengthBytes) + } + encodedBytes.append(contentsOf: self) + return encodedBytes + } + + } + +#endif diff --git a/Sources/Auth/Internal/JWTAlgorithm.swift b/Sources/Auth/Internal/JWTAlgorithm.swift new file mode 100644 index 000000000..4d43f370f --- /dev/null +++ b/Sources/Auth/Internal/JWTAlgorithm.swift @@ -0,0 +1,33 @@ +// +// JWTVerifier.swift +// Supabase +// +// Created by Claude on 06/10/25. +// + +import Foundation + +enum JWTAlgorithm: String { + case rs256 = "RS256" + + func verify( + jwt: DecodedJWT, + jwk: JWK + ) -> Bool { + let message = "\(jwt.raw.header).\(jwt.raw.payload)".data(using: .utf8)! + switch self { + case .rs256: + #if canImport(Security) + return SecKeyVerifySignature( + jwk.rsaPublishKey!, + .rsaSignatureMessagePKCS1v15SHA256, + message as CFData, + jwt.signature as CFData, + nil + ) + #else + return false + #endif + } + } +} diff --git a/Sources/Auth/Types.swift b/Sources/Auth/Types.swift index 2cf82812d..f0400d4e8 100644 --- a/Sources/Auth/Types.swift +++ b/Sources/Auth/Types.swift @@ -384,11 +384,11 @@ enum VerifyOTPParams: Encodable { func encode(to encoder: any Encoder) throws { var container = encoder.singleValueContainer() switch self { - case let .email(value): + case .email(let value): try container.encode(value) - case let .mobile(value): + case .mobile(let value): try container.encode(value) - case let .tokenHash(value): + case .tokenHash(let value): try container.encode(value) } } @@ -448,20 +448,20 @@ public enum AuthResponse: Codable, Hashable, Sendable { public func encode(to encoder: any Encoder) throws { var container = encoder.singleValueContainer() switch self { - case let .session(value): try container.encode(value) - case let .user(value): try container.encode(value) + case .session(let value): try container.encode(value) + case .user(let value): try container.encode(value) } } public var user: User { switch self { - case let .session(session): session.user - case let .user(user): user + case .session(let session): session.user + case .user(let user): user } } public var session: Session? { - if case let .session(session) = self { return session } + if case .session(let session) = self { return session } return nil } } @@ -1130,3 +1130,234 @@ public struct ListOAuthClientsPaginatedResponse: Hashable, Sendable { public var lastPage: Int public var total: Int } + +// MARK: - JWT Claims + +/// JSON Web Key (JWK) representation +public struct JWK: Codable, Hashable, Sendable { + /// Key type (e.g., "RSA", "EC", "oct") + public let kty: String + /// Key operations (e.g., ["sign", "verify"]) + public let keyOps: [String]? + /// Algorithm (e.g., "RS256", "ES256", "HS256") + public let alg: String? + /// Key ID + public let kid: String? + + // RSA-specific fields + /// RSA modulus (base64url-encoded) + public let n: String? + /// RSA exponent (base64url-encoded) + public let e: String? + + // EC-specific fields + /// EC curve name (e.g., "P-256") + public let crv: String? + /// EC x coordinate (base64url-encoded) + public let x: String? + /// EC y coordinate (base64url-encoded) + public let y: String? + + // Symmetric key field + /// Symmetric key value (base64url-encoded) + public let k: String? + + enum CodingKeys: String, CodingKey { + case kty + case keyOps = "key_ops" + case alg + case kid + case n + case e + case crv + case x + case y + case k + } +} + +/// JSON Web Key Set (JWKS) +public struct JWKS: Codable, Hashable, Sendable { + public let keys: [JWK] +} + +/// JWT Header +public struct JWTHeader: Codable, Hashable, Sendable { + /// Algorithm (e.g., "RS256", "ES256", "HS256") + public let alg: String + /// Key ID + public let kid: String? + /// Type (typically "JWT") + public let typ: String? +} + +/// JWT Claims +public struct JWTClaims: Codable, Hashable, Sendable { + /// Issuer + public let iss: String? + /// Subject + public let sub: String? + /// Audience + public let aud: AudienceClaim? + /// Expiration time + public let exp: TimeInterval? + /// Issued at + public let iat: TimeInterval? + /// Not before + public let nbf: TimeInterval? + /// JWT ID + public let jti: String? + /// Role + public let role: String? + /// Authenticator Assurance Level + public let aal: String? + /// Session ID + public let sessionId: String? + /// Email + public let email: String? + /// Phone + public let phone: String? + /// App metadata + public let appMetadata: [String: AnyJSON]? + /// User metadata + public let userMetadata: [String: AnyJSON]? + /// Additional claims + public var additionalClaims: [String: AnyJSON] = [:] + + enum CodingKeys: String, CodingKey { + case iss + case sub + case aud + case exp + case iat + case nbf + case jti + case role + case aal + case sessionId = "session_id" + case email + case phone + case appMetadata = "app_metadata" + case userMetadata = "user_metadata" + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + iss = try container.decodeIfPresent(String.self, forKey: .iss) + sub = try container.decodeIfPresent(String.self, forKey: .sub) + aud = try container.decodeIfPresent(AudienceClaim.self, forKey: .aud) + exp = try container.decodeIfPresent(TimeInterval.self, forKey: .exp) + iat = try container.decodeIfPresent(TimeInterval.self, forKey: .iat) + nbf = try container.decodeIfPresent(TimeInterval.self, forKey: .nbf) + jti = try container.decodeIfPresent(String.self, forKey: .jti) + role = try container.decodeIfPresent(String.self, forKey: .role) + aal = try container.decodeIfPresent(String.self, forKey: .aal) + sessionId = try container.decodeIfPresent(String.self, forKey: .sessionId) + email = try container.decodeIfPresent(String.self, forKey: .email) + phone = try container.decodeIfPresent(String.self, forKey: .phone) + appMetadata = try container.decodeIfPresent([String: AnyJSON].self, forKey: .appMetadata) + userMetadata = try container.decodeIfPresent([String: AnyJSON].self, forKey: .userMetadata) + + // Decode additional claims + let allKeys = try decoder.container(keyedBy: AnyCodingKey.self) + var additional: [String: AnyJSON] = [:] + for key in allKeys.allKeys where CodingKeys(stringValue: key.stringValue) == nil { + if let value = try? allKeys.decode(AnyJSON.self, forKey: key) { + additional[key.stringValue] = value + } + } + additionalClaims = additional + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encodeIfPresent(self.iss, forKey: .iss) + try container.encodeIfPresent(self.sub, forKey: .sub) + try container.encodeIfPresent(self.aud, forKey: .aud) + try container.encodeIfPresent(self.exp, forKey: .exp) + try container.encodeIfPresent(self.iat, forKey: .iat) + try container.encodeIfPresent(self.nbf, forKey: .nbf) + try container.encodeIfPresent(self.jti, forKey: .jti) + try container.encodeIfPresent(self.role, forKey: .role) + try container.encodeIfPresent(self.aal, forKey: .aal) + try container.encodeIfPresent(self.sessionId, forKey: .sessionId) + try container.encodeIfPresent(self.email, forKey: .email) + try container.encodeIfPresent(self.phone, forKey: .phone) + try container.encodeIfPresent(self.appMetadata, forKey: .appMetadata) + try container.encodeIfPresent(self.userMetadata, forKey: .userMetadata) + + var additionalClaimsContainer = encoder.container(keyedBy: AnyCodingKey.self) + for (key, value) in additionalClaims { + try additionalClaimsContainer.encode(value, forKey: AnyCodingKey(stringValue: key)!) + } + } +} + +/// Audience claim can be either a string or an array of strings +public enum AudienceClaim: Codable, Hashable, Sendable { + case string(String) + case array([String]) + + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + if let string = try? container.decode(String.self) { + self = .string(string) + } else if let array = try? container.decode([String].self) { + self = .array(array) + } else { + throw DecodingError.typeMismatch( + AudienceClaim.self, + DecodingError.Context( + codingPath: decoder.codingPath, + debugDescription: "Expected String or [String] for audience claim" + ) + ) + } + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .string(let value): + try container.encode(value) + case .array(let value): + try container.encode(value) + } + } +} + +private struct AnyCodingKey: CodingKey { + var stringValue: String + var intValue: Int? + + init?(stringValue: String) { + self.stringValue = stringValue + intValue = nil + } + + init?(intValue: Int) { + stringValue = "\(intValue)" + self.intValue = intValue + } +} + +/// Response from getClaims method +public struct JWTClaimsResponse: Sendable { + public let claims: JWTClaims + public let header: JWTHeader + public let signature: Data +} + +/// Options for the getClaims method +public struct GetClaimsOptions: Sendable { + /// If set to `true` the `exp` claim will not be validated against the current time. + public let allowExpired: Bool + + /// If set, this JSON Web Key Set is going to have precedence over the cached value available on the server. + public let jwks: JWKS? + + public init(allowExpired: Bool = false, jwks: JWKS? = nil) { + self.allowExpired = allowExpired + self.jwks = jwks + } +} diff --git a/Sources/Helpers/Base64URL.swift b/Sources/Helpers/Base64URL.swift new file mode 100644 index 000000000..66926eedf --- /dev/null +++ b/Sources/Helpers/Base64URL.swift @@ -0,0 +1,32 @@ +// +// Base64URL.swift +// Supabase +// +// Created by Claude on 06/10/25. +// + +import Foundation + +package enum Base64URL { + /// Decodes a base64url-encoded string to Data + package static func decode(_ value: String) -> Data? { + var base64 = value.replacingOccurrences(of: "-", with: "+") + .replacingOccurrences(of: "_", with: "/") + let length = Double(base64.lengthOfBytes(using: .utf8)) + let requiredLength = 4 * ceil(length / 4.0) + let paddingLength = requiredLength - length + if paddingLength > 0 { + let padding = "".padding(toLength: Int(paddingLength), withPad: "=", startingAt: 0) + base64 = base64 + padding + } + return Data(base64Encoded: base64, options: .ignoreUnknownCharacters) + } + + /// Encodes Data to a base64url-encoded string + package static func encode(_ data: Data) -> String { + data.base64EncodedString() + .replacingOccurrences(of: "+", with: "-") + .replacingOccurrences(of: "/", with: "_") + .replacingOccurrences(of: "=", with: "") + } +} diff --git a/Sources/Helpers/JWT.swift b/Sources/Helpers/JWT.swift index 86dfb5d0c..983e231a9 100644 --- a/Sources/Helpers/JWT.swift +++ b/Sources/Helpers/JWT.swift @@ -7,6 +7,13 @@ import Foundation +package struct DecodedJWT { + package let header: [String: Any] + package let payload: [String: Any] + package let signature: Data + package let raw: (header: String, payload: String) +} + package enum JWT { package static func decodePayload(_ jwt: String) -> [String: Any]? { let parts = jwt.split(separator: ".") @@ -15,7 +22,7 @@ package enum JWT { } let payload = String(parts[1]) - guard let data = base64URLDecode(payload) else { + guard let data = Base64URL.decode(payload) else { return nil } let json = try? JSONSerialization.jsonObject(with: data, options: []) @@ -25,16 +32,31 @@ package enum JWT { return decodedPayload } - private static func base64URLDecode(_ value: String) -> Data? { - var base64 = value.replacingOccurrences(of: "-", with: "+") - .replacingOccurrences(of: "_", with: "/") - let length = Double(base64.lengthOfBytes(using: .utf8)) - let requiredLength = 4 * ceil(length / 4.0) - let paddingLength = requiredLength - length - if paddingLength > 0 { - let padding = "".padding(toLength: Int(paddingLength), withPad: "=", startingAt: 0) - base64 = base64 + padding + package static func decode(_ jwt: String) -> DecodedJWT? { + let parts = jwt.split(separator: ".") + guard parts.count == 3 else { + return nil + } + + let headerString = String(parts[0]) + let payloadString = String(parts[1]) + let signatureString = String(parts[2]) + + guard + let headerData = Base64URL.decode(headerString), + let payloadData = Base64URL.decode(payloadString), + let signatureData = Base64URL.decode(signatureString), + let headerJSON = try? JSONSerialization.jsonObject(with: headerData, options: []) as? [String: Any], + let payloadJSON = try? JSONSerialization.jsonObject(with: payloadData, options: []) as? [String: Any] + else { + return nil } - return Data(base64Encoded: base64, options: .ignoreUnknownCharacters) + + return DecodedJWT( + header: headerJSON, + payload: payloadJSON, + signature: signatureData, + raw: (header: headerString, payload: payloadString) + ) } } diff --git a/Tests/AuthTests/AuthClientTests.swift b/Tests/AuthTests/AuthClientTests.swift index 19f58bbbb..4bb919ac3 100644 --- a/Tests/AuthTests/AuthClientTests.swift +++ b/Tests/AuthTests/AuthClientTests.swift @@ -2227,6 +2227,345 @@ final class AuthClientTests: XCTestCase { XCTAssertNil(Dependencies[sut.clientID].sessionStorage.get()) } + // MARK: - getClaims Tests + + func testGetClaims_withHS256JWT_shouldFallbackAndReturnClaims() async throws { + // HS256 JWT (symmetric algorithm) - will use server-side verification + let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo1NDMyMS9hdXRoL3YxIiwiYXVkIjoiYXV0aGVudGljYXRlZCIsImV4cCI6OTk5OTk5OTk5OSwiaWF0IjoxNTE2MjM5MDIyLCJyb2xlIjoiYXV0aGVudGljYXRlZCJ9.4Adcj0vZKqXRB_mPpDVkWvB3xw7yHYjpzGJLKFQjKEc" + + let user = User(fromMockNamed: "user") + + Mock( + url: clientURL.appendingPathComponent("user"), + ignoreQuery: true, + contentType: .json, + statusCode: 200, + data: [.get: try! AuthClient.Configuration.jsonEncoder.encode(user)] + ).register() + + let sut = makeSUT() + + let result = try await sut.getClaims(jwt: jwt) + + XCTAssertEqual(result.claims.sub, "1234567890") + XCTAssertEqual(result.claims.iss, "http://localhost:54321/auth/v1") + if case let .string(aud) = result.claims.aud { + XCTAssertEqual(aud, "authenticated") + } else { + XCTFail("Expected string audience") + } + XCTAssertEqual(result.claims.role, "authenticated") + XCTAssertEqual(result.header.alg, "HS256") + XCTAssertNil(result.header.kid) + } + + func testGetClaims_withoutJWT_shouldUseSessionAccessToken() async throws { + // HS256 JWT from session + let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo1NDMyMS9hdXRoL3YxIiwiYXVkIjoiYXV0aGVudGljYXRlZCIsImV4cCI6OTk5OTk5OTk5OSwiaWF0IjoxNTE2MjM5MDIyLCJyb2xlIjoiYXV0aGVudGljYXRlZCJ9.4Adcj0vZKqXRB_mPpDVkWvB3xw7yHYjpzGJLKFQjKEc" + + var session = Session.validSession + session.accessToken = jwt + + let user = User(fromMockNamed: "user") + + Mock( + url: clientURL.appendingPathComponent("user"), + ignoreQuery: true, + contentType: .json, + statusCode: 200, + data: [.get: try! AuthClient.Configuration.jsonEncoder.encode(user)] + ).register() + + let sut = makeSUT() + Dependencies[sut.clientID].sessionStorage.store(session) + + let result = try await sut.getClaims() + + XCTAssertEqual(result.claims.sub, "1234567890") + XCTAssertEqual(result.claims.role, "authenticated") + } + + func testGetClaims_withProvidedJWKS_shouldStillFallbackForES256() async throws { + // ES256 is not yet supported client-side, so it will fallback to server even with JWKS + let jwt = "eyJhbGciOiJFUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo1NDMyMS9hdXRoL3YxIiwiYXVkIjoiYXV0aGVudGljYXRlZCIsImV4cCI6OTk5OTk5OTk5OSwiaWF0IjoxNTE2MjM5MDIyLCJyb2xlIjoiYXV0aGVudGljYXRlZCJ9.dummysignature" + + // JWK is Codable, no custom init needed + let jwkDict: [String: Any] = [ + "kty": "EC", + "kid": "test-kid", + "alg": "ES256", + "crv": "P-256", + "x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", + "y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM" + ] + + let jwkData = try JSONSerialization.data(withJSONObject: jwkDict) + let jwk = try AuthClient.Configuration.jsonDecoder.decode(JWK.self, from: jwkData) + let jwks = JWKS(keys: [jwk]) + + let user = User(fromMockNamed: "user") + + Mock( + url: clientURL.appendingPathComponent("user"), + ignoreQuery: true, + contentType: .json, + statusCode: 200, + data: [.get: try! AuthClient.Configuration.jsonEncoder.encode(user)] + ).register() + + let sut = makeSUT() + + let result = try await sut.getClaims(jwt: jwt, options: GetClaimsOptions(jwks: jwks)) + + XCTAssertEqual(result.claims.sub, "1234567890") + XCTAssertEqual(result.claims.role, "authenticated") + } + + func testGetClaims_withES256JWT_shouldFallbackToServerVerification() async throws { + // ES256 JWT without kid - will fallback to server + let jwt = "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo1NDMyMS9hdXRoL3YxIiwiYXVkIjoiYXV0aGVudGljYXRlZCIsImV4cCI6OTk5OTk5OTk5OSwiaWF0IjoxNTE2MjM5MDIyLCJyb2xlIjoiYXV0aGVudGljYXRlZCJ9.dummysignature" + + let user = User(fromMockNamed: "user") + + Mock( + url: clientURL.appendingPathComponent("user"), + ignoreQuery: true, + contentType: .json, + statusCode: 200, + data: [.get: try! AuthClient.Configuration.jsonEncoder.encode(user)] + ).register() + + let sut = makeSUT() + + let result = try await sut.getClaims(jwt: jwt) + + XCTAssertEqual(result.claims.sub, "1234567890") + XCTAssertEqual(result.claims.role, "authenticated") + } + + func testGetClaims_withRS256JWT_whenJWKNotFound_shouldFallbackToServerVerification() async throws { + // RS256 JWT with kid but key not in JWKS - will try to fetch JWKS, not find it, then fallback to server + let jwt = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo1NDMyMS9hdXRoL3YxIiwiYXVkIjoiYXV0aGVudGljYXRlZCIsImV4cCI6OTk5OTk5OTk5OSwiaWF0IjoxNTE2MjM5MDIyLCJyb2xlIjoiYXV0aGVudGljYXRlZCJ9.dummysignature" + + // Mock JWKS endpoint with different kid + let jwkDict: [String: Any] = [ + "kty": "RSA", + "kid": "different-kid", + "alg": "RS256", + "n": "modulus", + "e": "AQAB" + ] + let jwkData = try JSONSerialization.data(withJSONObject: jwkDict) + let jwk = try AuthClient.Configuration.jsonDecoder.decode(JWK.self, from: jwkData) + let jwks = JWKS(keys: [jwk]) + + Mock( + url: clientURL.appendingPathComponent(".well-known/jwks.json"), + ignoreQuery: true, + contentType: .json, + statusCode: 200, + data: [.get: try! AuthClient.Configuration.jsonEncoder.encode(jwks)] + ).register() + + let user = User(fromMockNamed: "user") + + Mock( + url: clientURL.appendingPathComponent("user"), + ignoreQuery: true, + contentType: .json, + statusCode: 200, + data: [.get: try! AuthClient.Configuration.jsonEncoder.encode(user)] + ).register() + + let sut = makeSUT() + + let result = try await sut.getClaims(jwt: jwt) + + XCTAssertEqual(result.claims.sub, "1234567890") + XCTAssertEqual(result.claims.role, "authenticated") + } + + func testGetClaims_withNoKidInHeader_shouldFallbackToServerVerification() async throws { + // JWT without kid - cannot look up in JWKS + let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI5ODc2NTQzMjEiLCJpc3MiOiJodHRwOi8vbG9jYWxob3N0OjU0MzIxL2F1dGgvdjEiLCJhdWQiOiJhdXRoZW50aWNhdGVkIiwiZXhwIjo5OTk5OTk5OTk5LCJpYXQiOjE1MTYyMzkwMjIsInJvbGUiOiJhdXRoZW50aWNhdGVkIn0.YT0NvH-jYKCiN-wrAVcMmTIxZkQ3OtqTVFjJAqGcRuw" + + let user = User(fromMockNamed: "user") + + Mock( + url: clientURL.appendingPathComponent("user"), + ignoreQuery: true, + contentType: .json, + statusCode: 200, + data: [.get: try! AuthClient.Configuration.jsonEncoder.encode(user)] + ).register() + + let sut = makeSUT() + + let result = try await sut.getClaims(jwt: jwt) + + XCTAssertEqual(result.claims.sub, "987654321") + XCTAssertEqual(result.claims.role, "authenticated") + } + + func testGetClaims_withoutJWTAndNoSession_shouldThrowSessionMissing() async throws { + let sut = makeSUT() + + do { + _ = try await sut.getClaims() + XCTFail("Expected sessionMissing error") + } catch let error as AuthError { + guard case .sessionMissing = error else { + XCTFail("Expected sessionMissing error, got \(error)") + return + } + } catch { + XCTFail("Expected AuthError, got \(error)") + } + } + + func testGetClaims_withInvalidJWTStructure_shouldThrowJWTVerificationFailed() async throws { + let invalidJWT = "invalid.jwt.token" + + let sut = makeSUT() + + do { + _ = try await sut.getClaims(jwt: invalidJWT) + XCTFail("Expected jwtVerificationFailed error") + } catch let error as AuthError { + guard case .jwtVerificationFailed(let message) = error else { + XCTFail("Expected jwtVerificationFailed error, got \(error)") + return + } + XCTAssertEqual(message, "Invalid JWT structure") + } catch { + XCTFail("Expected AuthError, got \(error)") + } + } + + func testGetClaims_withExpiredJWT_shouldThrowJWTVerificationFailed() async throws { + // JWT with exp in the past + let expiredJWT = "eyJhbGciOiJFUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo1NDMyMS9hdXRoL3YxIiwiYXVkIjoiYXV0aGVudGljYXRlZCIsImV4cCI6MTUxNjIzOTAyMiwiaWF0IjoxNTE2MjM5MDIyLCJyb2xlIjoiYXV0aGVudGljYXRlZCJ9.MEYCIQDmtLy0PF_lR7rJQHyKLmJKp1xFKECfVvGTBcXiVnz0jAIhAOoXZJ3kHSA2MqL1XhcUy8dWOZCr6zWCN_FXsP8qKfPR" + + let sut = makeSUT() + + do { + _ = try await sut.getClaims(jwt: expiredJWT) + XCTFail("Expected jwtVerificationFailed error") + } catch let error as AuthError { + guard case .jwtVerificationFailed(let message) = error else { + XCTFail("Expected jwtVerificationFailed error, got \(error)") + return + } + XCTAssertEqual(message, "JWT has expired") + } catch { + XCTFail("Expected AuthError, got \(error)") + } + } + + func testGetClaims_withExpiredJWTAndAllowExpired_shouldReturnClaims() async throws { + // JWT with exp in the past but allowExpired option - falls back to server + let expiredJWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo1NDMyMS9hdXRoL3YxIiwiYXVkIjoiYXV0aGVudGljYXRlZCIsImV4cCI6MTUxNjIzOTAyMiwiaWF0IjoxNTE2MjM5MDIyLCJyb2xlIjoiYXV0aGVudGljYXRlZCJ9.aN0HLYHkp7nKZp4xWvBaDqSrCFBxk2tq0KZc4BXGqYs" + + let user = User(fromMockNamed: "user") + + Mock( + url: clientURL.appendingPathComponent("user"), + ignoreQuery: true, + contentType: .json, + statusCode: 200, + data: [.get: try! AuthClient.Configuration.jsonEncoder.encode(user)] + ).register() + + let sut = makeSUT() + + let result = try await sut.getClaims(jwt: expiredJWT, options: GetClaimsOptions(allowExpired: true)) + + XCTAssertEqual(result.claims.sub, "1234567890") + XCTAssertEqual(result.claims.exp, 1516239022) + } + + func testGetClaims_whenServerRejectsJWT_shouldThrowError() async throws { + // HS256 JWT that will be verified server-side + let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo1NDMyMS9hdXRoL3YxIiwiYXVkIjoiYXV0aGVudGljYXRlZCIsImV4cCI6OTk5OTk5OTk5OSwiaWF0IjoxNTE2MjM5MDIyLCJyb2xlIjoiYXV0aGVudGljYXRlZCJ9.4Adcj0vZKqXRB_mPpDVkWvB3xw7yHYjpzGJLKFQjKEc" + + Mock( + url: clientURL.appendingPathComponent("user"), + ignoreQuery: true, + contentType: .json, + statusCode: 401, + data: [.get: try! AuthClient.Configuration.jsonEncoder.encode([ + "error": "invalid_token", + "error_description": "Invalid JWT" + ])] + ).register() + + let sut = makeSUT() + + do { + _ = try await sut.getClaims(jwt: jwt) + XCTFail("Expected error from server") + } catch { + // Expected to fail + } + } + + func testGetClaims_withComplexClaims_shouldDecodeAllFields() async throws { + // JWT with multiple claim fields + // HS256 so it falls back to server verification + let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo1NDMyMS9hdXRoL3YxIiwiYXVkIjoiYXV0aGVudGljYXRlZCIsImV4cCI6OTk5OTk5OTk5OSwiaWF0IjoxNTE2MjM5MDIyLCJuYmYiOjE1MTYyMzkwMjIsImp0aSI6InRlc3QtanRpIiwicm9sZSI6ImF1dGhlbnRpY2F0ZWQiLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20iLCJwaG9uZSI6IisxMjM0NTY3ODkwIn0.dBYm1Y-TfRjPsxw_gXqHB5zGHSH9hXS0OeFN_wL8HbA" + + let user = User(fromMockNamed: "user") + + Mock( + url: clientURL.appendingPathComponent("user"), + ignoreQuery: true, + contentType: .json, + statusCode: 200, + data: [.get: try! AuthClient.Configuration.jsonEncoder.encode(user)] + ).register() + + let sut = makeSUT() + + let result = try await sut.getClaims(jwt: jwt) + + XCTAssertEqual(result.claims.sub, "1234567890") + XCTAssertEqual(result.claims.iss, "http://localhost:54321/auth/v1") + if case let .string(aud) = result.claims.aud { + XCTAssertEqual(aud, "authenticated") + } else { + XCTFail("Expected string audience") + } + XCTAssertEqual(result.claims.exp, 9999999999) + XCTAssertEqual(result.claims.iat, 1516239022) + XCTAssertEqual(result.claims.nbf, 1516239022) + XCTAssertEqual(result.claims.jti, "test-jti") + XCTAssertEqual(result.claims.role, "authenticated") + XCTAssertEqual(result.claims.email, "test@example.com") + XCTAssertEqual(result.claims.phone, "+1234567890") + } + + func testGetClaims_withArrayAudience_shouldDecodeCorrectly() async throws { + // JWT with audience as array + let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo1NDMyMS9hdXRoL3YxIiwiYXVkIjpbImF1dGhlbnRpY2F0ZWQiLCJzZXJ2aWNlLXJvbGUiXSwiZXhwIjo5OTk5OTk5OTk5LCJpYXQiOjE1MTYyMzkwMjIsInJvbGUiOiJhdXRoZW50aWNhdGVkIn0.Jz-lHQoR2VsQ_vX8wKyN7mPxT4aU9cF1bYsHqGdWlIk" + + let user = User(fromMockNamed: "user") + + Mock( + url: clientURL.appendingPathComponent("user"), + ignoreQuery: true, + contentType: .json, + statusCode: 200, + data: [.get: try! AuthClient.Configuration.jsonEncoder.encode(user)] + ).register() + + let sut = makeSUT() + + let result = try await sut.getClaims(jwt: jwt) + + XCTAssertEqual(result.claims.sub, "1234567890") + XCTAssertNotNil(result.claims.aud) + } + private func makeSUT(flowType: AuthFlowType = .pkce) -> AuthClient { let sessionConfiguration = URLSessionConfiguration.default sessionConfiguration.protocolClasses = [MockingURLProtocol.self]