Skip to content

Commit e9c7c8c

Browse files
authored
fix(auth): extract both query and fragment from URL (#365)
* fix(auth): extract both query and fragment from URL * test: query takes precedence
1 parent f1e17ee commit e9c7c8c

File tree

3 files changed

+42
-32
lines changed

3 files changed

+42
-32
lines changed

Sources/Auth/AuthClient.swift

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -595,32 +595,30 @@ public final class AuthClient: Sendable {
595595
let params = extractParams(from: url)
596596

597597
if isPKCEFlow(url: url) {
598-
guard let code = params.first(where: { $0.name == "code" })?.value else {
598+
guard let code = params["code"] else {
599599
throw AuthError.pkce(.codeVerifierNotFound)
600600
}
601601

602602
let session = try await exchangeCodeForSession(authCode: code)
603603
return session
604604
}
605605

606-
if let errorDescription = params.first(where: { $0.name == "error_description" })?.value {
606+
if let errorDescription = params["error_description"] {
607607
throw AuthError.api(.init(errorDescription: errorDescription))
608608
}
609609

610610
guard
611-
let accessToken = params.first(where: { $0.name == "access_token" })?.value,
612-
let expiresIn = params.first(where: { $0.name == "expires_in" }).map(\.value)
613-
.flatMap(TimeInterval.init),
614-
let refreshToken = params.first(where: { $0.name == "refresh_token" })?.value,
615-
let tokenType = params.first(where: { $0.name == "token_type" })?.value
611+
let accessToken = params["access_token"],
612+
let expiresIn = params["expires_in"].flatMap(TimeInterval.init),
613+
let refreshToken = params["refresh_token"],
614+
let tokenType = params["token_type"]
616615
else {
617616
throw URLError(.badURL)
618617
}
619618

620-
let expiresAt = params.first(where: { $0.name == "expires_at" }).map(\.value)
621-
.flatMap(TimeInterval.init)
622-
let providerToken = params.first(where: { $0.name == "provider_token" })?.value
623-
let providerRefreshToken = params.first(where: { $0.name == "provider_refresh_token" })?.value
619+
let expiresAt = params["expires_at"].flatMap(TimeInterval.init)
620+
let providerToken = params["provider_token"]
621+
let providerRefreshToken = params["provider_refresh_token"]
624622

625623
let user = try await api.execute(
626624
.init(
@@ -644,7 +642,7 @@ public final class AuthClient: Sendable {
644642
try await sessionManager.update(session)
645643
eventEmitter.emit(.signedIn, session: session)
646644

647-
if let type = params.first(where: { $0.name == "type" })?.value, type == "recovery" {
645+
if let type = params["type"], type == "recovery" {
648646
eventEmitter.emit(.passwordRecovery, session: session)
649647
}
650648

@@ -1060,15 +1058,13 @@ public final class AuthClient: Sendable {
10601058

10611059
private func isImplicitGrantFlow(url: URL) -> Bool {
10621060
let fragments = extractParams(from: url)
1063-
return fragments.contains {
1064-
$0.name == "access_token" || $0.name == "error_description"
1065-
}
1061+
return fragments["access_token"] != nil || fragments["error_description"] != nil
10661062
}
10671063

10681064
private func isPKCEFlow(url: URL) -> Bool {
10691065
let fragments = extractParams(from: url)
10701066
let currentCodeVerifier = codeVerifierStorage.get()
1071-
return fragments.contains(where: { $0.name == "code" }) && currentCodeVerifier != nil
1067+
return fragments["code"] != nil && currentCodeVerifier != nil
10721068
}
10731069

10741070
private func getURLForProvider(

Sources/Auth/Internal/Helpers.swift

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,30 @@
11
import Foundation
22

3-
struct Params: Hashable {
4-
var name: String
5-
var value: String
6-
}
7-
8-
func extractParams(from url: URL) -> [Params] {
3+
/// Extracts parameters encoded in the URL both in the query and fragment.
4+
func extractParams(from url: URL) -> [String: String] {
95
guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false) else {
10-
return []
6+
return [:]
117
}
128

9+
var result: [String: String] = [:]
10+
1311
if let fragment = components.fragment {
14-
return extractParams(from: fragment)
12+
let items = extractParams(from: fragment)
13+
for item in items {
14+
result[item.name] = item.value
15+
}
1516
}
1617

17-
if let queryItems = components.queryItems {
18-
return queryItems.map {
19-
Params(name: $0.name, value: $0.value ?? "")
18+
if let items = components.queryItems {
19+
for item in items {
20+
result[item.name] = item.value
2021
}
2122
}
2223

23-
return []
24+
return result
2425
}
2526

26-
func extractParams(from fragment: String) -> [Params] {
27+
private func extractParams(from fragment: String) -> [URLQueryItem] {
2728
let components =
2829
fragment
2930
.split(separator: "&")
@@ -33,7 +34,7 @@ func extractParams(from fragment: String) -> [Params] {
3334
components
3435
.compactMap {
3536
$0.count == 2
36-
? Params(name: String($0[0]), value: String($0[1]))
37+
? URLQueryItem(name: String($0[0]), value: String($0[1]))
3738
: nil
3839
}
3940
}

Tests/AuthTests/ExtractParamsTests.swift

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,26 @@ final class ExtractParamsTests: XCTestCase {
1313
let code = UUID().uuidString
1414
let url = URL(string: "io.supabase.flutterquickstart://login-callback/?code=\(code)")!
1515
let params = extractParams(from: url)
16-
XCTAssertEqual(params, [Params(name: "code", value: code)])
16+
XCTAssertEqual(params, ["code": code])
1717
}
1818

1919
func testExtractParamsInFragment() {
2020
let code = UUID().uuidString
2121
let url = URL(string: "io.supabase.flutterquickstart://login-callback/#code=\(code)")!
2222
let params = extractParams(from: url)
23-
XCTAssertEqual(params, [Params(name: "code", value: code)])
23+
XCTAssertEqual(params, ["code": code])
24+
}
25+
26+
func testExtractParamsInBothFragmentAndQuery() {
27+
let code = UUID().uuidString
28+
let url = URL(string: "io.supabase.flutterquickstart://login-callback/?code=\(code)#message=abc")!
29+
let params = extractParams(from: url)
30+
XCTAssertEqual(params, ["code": code, "message": "abc"])
31+
}
32+
33+
func testExtractParamsQueryTakesPrecedence() {
34+
let url = URL(string: "io.supabase.flutterquickstart://login-callback/?code=123#code=abc")!
35+
let params = extractParams(from: url)
36+
XCTAssertEqual(params, ["code": "123"])
2437
}
2538
}

0 commit comments

Comments
 (0)