Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,11 @@ import com.atproto.server.RefreshSessionResponse
import com.tunjid.heron.data.core.models.OauthUriRequest
import com.tunjid.heron.data.core.models.Server
import com.tunjid.heron.data.core.models.SessionRequest
import com.tunjid.heron.data.core.types.GenericUri
import com.tunjid.heron.data.core.types.ProfileHandle
import com.tunjid.heron.data.core.types.ProfileId
import com.tunjid.heron.data.lexicons.XrpcBlueskyApi
import com.tunjid.heron.data.lexicons.XrpcSerializersModule
import com.tunjid.heron.data.network.oauth.DpopKeyPair
import com.tunjid.heron.data.network.oauth.OAuthApi
import com.tunjid.heron.data.network.oauth.OAuthAuthorizationRequest
import com.tunjid.heron.data.network.oauth.OAuthClient
import com.tunjid.heron.data.network.oauth.OAuthScope
import com.tunjid.heron.data.network.oauth.OAuthToken
Expand Down Expand Up @@ -65,6 +62,7 @@ import io.ktor.http.encodedPath
import io.ktor.http.isSuccess
import io.ktor.http.set
import io.ktor.http.takeFrom
import kotlin.time.Clock
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.first
Expand All @@ -77,9 +75,9 @@ import sh.christian.ozone.api.runtime.buildXrpcJsonConfiguration

internal interface SessionManager {

suspend fun startOauthSessionUri(
suspend fun initiateOauthSession(
request: OauthUriRequest,
): GenericUri
): SavedState.AuthTokens.Pending

suspend fun createSession(
request: SessionRequest,
Expand Down Expand Up @@ -125,25 +123,26 @@ internal class PersistedSessionManager @Inject constructor(
client = authHttpClient,
)

private var pendingOauthSession: OauthSession? = null

override suspend fun startOauthSessionUri(
override suspend fun initiateOauthSession(
request: OauthUriRequest,
): GenericUri {
): SavedState.AuthTokens.Pending {
sessionRequestUrl.update { Url(request.server.endpoint) }
return oAuthApi.buildAuthorizationRequest(
oauthClient = HeronOauthClient,
scopes = HeronOauthScopes,
loginHandleHint = request.handle.id,
)
.also {
pendingOauthSession = OauthSession(
handle = request.handle,
request = it,
.let {
SavedState.AuthTokens.Pending.DPoP(
profileHandle = request.handle,
endpoint = request.server.endpoint,
authorizeRequestUrl = it.authorizeRequestUrl,
codeVerifier = it.codeVerifier,
nonce = it.nonce,
state = it.state,
expiresAt = Clock.System.now() + it.expiresIn,
)
}
.authorizeRequestUrl
.let(::GenericUri)
}

override suspend fun createSession(
Expand Down Expand Up @@ -171,36 +170,44 @@ internal class PersistedSessionManager @Inject constructor(
}
.requireResponse()
is SessionRequest.Oauth -> {
val pendingRequest = pendingOauthSession
?: throw IllegalStateException("Expired authentication session")
val existingAuth = savedStateDataSource.savedState.value.auth
val pendingRequest = existingAuth as? SavedState.AuthTokens.Pending.DPoP
?: throw IllegalStateException("No pending oauth session to finalize. Current auth state: $existingAuth")

try {
val callbackUrl = Url(request.callbackUri.uri)
require(request.server.endpoint == pendingRequest.endpoint) {
"Mismatched server endpoints in OAuth flow. Expected ${pendingRequest.endpoint}, but got ${request.server.endpoint}"
}

val code = callbackUrl.parameters[OauthCallbackUriCodeParam]
?: throw IllegalStateException("No auth code")
val callbackUrl = Url(request.callbackUri.uri)

val oAuthToken = oAuthApi.requestToken(
oauthClient = HeronOauthClient,
nonce = pendingRequest.request.nonce,
codeVerifier = pendingRequest.request.codeVerifier,
code = code,
)
val state = callbackUrl.parameters["state"]
?: throw IllegalStateException("No state in callback")

val callingDid = api.resolveHandle(
ResolveHandleQueryParams(Handle(pendingRequest.handle.id)),
)
.requireResponse()
.did
require(state == pendingRequest.state) {
"Mismatched state in OAuth callback. Expected ${pendingRequest.state}, but got $state"
}

if (oAuthToken.subject != callingDid) {
throw IllegalStateException("Invalid login session")
}
val code = callbackUrl.parameters[OauthCallbackUriCodeParam]
?: throw IllegalStateException("No auth code")

oAuthToken.toAppToken(authEndpoint = request.server.endpoint)
} finally {
pendingOauthSession = null
val oAuthToken = oAuthApi.requestToken(
oauthClient = HeronOauthClient,
nonce = pendingRequest.nonce,
codeVerifier = pendingRequest.codeVerifier,
code = code,
)

val callingDid = api.resolveHandle(
ResolveHandleQueryParams(Handle(pendingRequest.profileHandle.id)),
)
.requireResponse()
.did

if (oAuthToken.subject != callingDid) {
throw IllegalStateException("Invalid login session")
}

oAuthToken.toAppToken(authEndpoint = request.server.endpoint)
}
is SessionRequest.Guest -> SavedState.AuthTokens.Guest(
server = request.server,
Expand All @@ -220,6 +227,7 @@ internal class PersistedSessionManager @Inject constructor(
keyPair = authTokens.toKeyPair(),
)
is SavedState.AuthTokens.Guest,
is SavedState.AuthTokens.Pending,
null,
-> Unit
}
Expand Down Expand Up @@ -502,6 +510,7 @@ private val SavedState.AuthTokens?.defaultUrl
is SavedState.AuthTokens.Authenticated.Bearer -> authEndpoint
is SavedState.AuthTokens.Authenticated.DPoP -> issuerEndpoint
is SavedState.AuthTokens.Guest -> server.endpoint
is SavedState.AuthTokens.Pending.DPoP -> endpoint
null -> Server.BlueSky.endpoint
}

Expand All @@ -511,11 +520,6 @@ private val SavedState.AuthTokens.Authenticated.singleAccessKey
is SavedState.AuthTokens.Authenticated.DPoP -> "$auth-$refresh"
}

private class OauthSession(
val handle: ProfileHandle,
val request: OAuthAuthorizationRequest,
)

internal val BlueskyJson: Json = Json(
from = buildXrpcJsonConfiguration(XrpcSerializersModule),
builderAction = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,13 @@ internal class AuthTokenRepository(
override suspend fun oauthRequestUri(
request: OauthUriRequest,
): Result<GenericUri> = runCatchingUnlessCancelled {
sessionManager.startOauthSessionUri(request)
when (val pendingToken = sessionManager.initiateOauthSession(request)) {
is SavedState.AuthTokens.Pending.DPoP -> {
savedStateDataSource.setAuth(pendingToken)
pendingToken.authorizeRequestUrl
.let(::GenericUri)
}
}
}

override suspend fun createSession(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.tunjid.heron.data.core.models.ContentLabelPreference
import com.tunjid.heron.data.core.models.Label
import com.tunjid.heron.data.core.models.Preferences
import com.tunjid.heron.data.core.models.Server
import com.tunjid.heron.data.core.types.ProfileHandle
import com.tunjid.heron.data.core.types.ProfileId
import com.tunjid.heron.data.core.types.Uri
import com.tunjid.heron.data.core.utilities.Outcome
Expand Down Expand Up @@ -73,6 +74,23 @@ abstract class SavedState {
override val authProfileId: ProfileId = Constants.unknownAuthorId
}

@Serializable
sealed class Pending : AuthTokens() {

@Serializable
data class DPoP(
val profileHandle: ProfileHandle,
val endpoint: String,
val authorizeRequestUrl: String,
val codeVerifier: String,
val nonce: String,
val state: String,
val expiresAt: Instant,
) : Pending() {
override val authProfileId: ProfileId = Constants.unknownAuthorId
}
}

@Serializable
sealed class Authenticated : AuthTokens() {

Expand Down Expand Up @@ -194,13 +212,15 @@ internal fun SavedState.signedProfilePreferencesOrDefault(): Preferences =
is SavedState.AuthTokens.Authenticated.Bearer -> authTokens.authEndpoint
is SavedState.AuthTokens.Authenticated.DPoP -> authTokens.issuerEndpoint
is SavedState.AuthTokens.Guest -> authTokens.server.endpoint
is SavedState.AuthTokens.Pending.DPoP -> authTokens.endpoint
null -> Server.BlueSky.endpoint
}.let(::preferencesForUrl)

private fun SavedState.AuthTokens?.ifSignedIn(): SavedState.AuthTokens.Authenticated? =
when (this) {
is SavedState.AuthTokens.Authenticated -> this
is SavedState.AuthTokens.Guest,
is SavedState.AuthTokens.Pending,
null,
-> null
}
Expand Down Expand Up @@ -378,6 +398,6 @@ internal inline fun <T> SavedStateDataSource.singleSessionFlow(
block(signedInProfileId)
}

internal fun expiredSessionOutcome() = Outcome.Failure(ExpiredSessionException)
internal fun expiredSessionOutcome() = Outcome.Failure(ExpiredSessionException())

private object ExpiredSessionException : IOException()
private class ExpiredSessionException : IOException()
Loading