Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import app.bsky.notification.ListNotificationsResponse
import app.bsky.notification.Preference
import app.bsky.notification.PutPreferencesV2Request
import app.bsky.notification.UpdateSeenRequest
import com.atproto.server.GetServiceAuthQueryParams
import com.tunjid.heron.data.InternalEndpoints
import com.tunjid.heron.data.core.models.Block
import com.tunjid.heron.data.core.models.Cursor
Expand All @@ -51,6 +52,7 @@ import com.tunjid.heron.data.core.models.shouldShowNotification
import com.tunjid.heron.data.core.models.value
import com.tunjid.heron.data.core.types.MutedThreadException
import com.tunjid.heron.data.core.types.NotificationFilteredOutException
import com.tunjid.heron.data.core.types.PostUri
import com.tunjid.heron.data.core.types.ProfileId
import com.tunjid.heron.data.core.types.RecordUri
import com.tunjid.heron.data.core.types.RepostUri
Expand Down Expand Up @@ -86,13 +88,15 @@ import dev.zacsweers.metro.Inject
import io.ktor.client.HttpClient
import io.ktor.client.plugins.DefaultRequest
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.request.bearerAuth
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.http.takeFrom
import kotlin.time.Clock
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds
import kotlin.time.Instant
import kotlinx.coroutines.CoroutineDispatcher
Expand All @@ -112,6 +116,8 @@ import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.plus
import kotlinx.serialization.Serializable
import sh.christian.ozone.api.Did
import sh.christian.ozone.api.Nsid
import sh.christian.ozone.api.response.AtpResponse

@Serializable
Expand All @@ -120,6 +126,7 @@ data class NotificationsQuery(
) : CursorQuery {
data class Push(
val senderId: ProfileId,
val targetDid: ProfileId,
val recordUri: RecordUri,
val reason: Notification.Reason,
)
Expand Down Expand Up @@ -302,18 +309,37 @@ internal class OfflineNotificationsRepository @Inject constructor(
token: String,
) = savedStateDataSource.inCurrentProfileSession { signedProfileId ->
if (signedProfileId == null) return@inCurrentProfileSession expiredSessionOutcome()
val saveNotificationTokenRequest = SaveNotificationTokenRequest(
did = signedProfileId.id,
token = token,
)
networkMonitor.runCatchingWithNetworkRetry(
block = {
notificationsClient.post(SaveNotificationTokenPath) {
contentType(ContentType.Application.Json)
setBody(saveNotificationTokenRequest)
}
},
).toOutcome()

networkService.runCatchingWithMonitoredNetworkRetry {
getServiceAuth(
GetServiceAuthQueryParams(
aud = Did(signedProfileId.id),
exp = Clock.System.now().epochSeconds + 5.minutes.inWholeSeconds,
lxm = Nsid(PostUri.NAMESPACE),
),
)
}.mapToResult { tokenResponse ->
val saveNotificationTokenRequest = SaveNotificationTokenRequest(
did = signedProfileId.id,
token = token,
otherDids = savedStateDataSource.savedState
.value.pastSessions
?.mapNotNull {
if (it.profileId == signedProfileId) null
else it.profileId.id
}
.orEmpty(),
)
networkMonitor.runCatchingWithNetworkRetry(
block = {
notificationsClient.post(SaveNotificationTokenPath) {
contentType(ContentType.Application.Json)
setBody(saveNotificationTokenRequest)
bearerAuth(tokenResponse.token)
}
},
)
}.toOutcome()
} ?: expiredSessionOutcome()

override suspend fun updateNotificationPreferences(
Expand Down Expand Up @@ -348,7 +374,9 @@ internal class OfflineNotificationsRepository @Inject constructor(
push = update.push,
)
when (update.reason) {
Notification.Reason.JoinedStarterPack -> currentPrefs.copy(starterpackJoined = simplePref)
Notification.Reason.JoinedStarterPack -> currentPrefs.copy(
starterpackJoined = simplePref,
)
Notification.Reason.SubscribedPost -> currentPrefs.copy(subscribedPost = simplePref)
Notification.Reason.Unverified -> currentPrefs.copy(unverified = simplePref)
Notification.Reason.Verified -> currentPrefs.copy(verified = simplePref)
Expand Down Expand Up @@ -405,12 +433,9 @@ internal class OfflineNotificationsRepository @Inject constructor(
query: NotificationsQuery.Push,
): Result<Notification> =
// Push notifications can be received for any profile that has been signed in
savedStateDataSource.inPastSession(query.recordUri.profileId()) { token ->
val signedInProfileId = token.authProfileId

savedStateDataSource.inProfileSession(query.targetDid) { signedInProfileId ->
recordResolver.resolve(query.recordUri)
.mapCatchingUnlessCancelled { resolvedRecord ->

val authorEntity = profileDao.profiles(
signedInProfiledId = signedInProfileId.id,
ids = listOf(query.senderId),
Expand Down Expand Up @@ -552,7 +577,9 @@ internal class OfflineNotificationsRepository @Inject constructor(
}
}

val notificationPreferences = savedStateDataSource.savedState.value.signedNotificationPreferencesOrDefault()
val notificationPreferences = profileData.notifications
.preferences
?: NotificationPreferences.Default

val isAuthorFollowed = viewerState?.isFollowing == true

Expand Down Expand Up @@ -690,6 +717,7 @@ private fun SavedState.signedInProfileNotifications() =
private data class SaveNotificationTokenRequest(
val did: String,
val token: String,
val otherDids: List<String>,
)

private const val SaveNotificationTokenPath = "/saveNotificationToken"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,26 +458,24 @@ internal suspend fun SavedStateDataSource.updateSignedInUserNotifications(
* Runs the [block] in the context of a single profile's session
*/
internal suspend inline fun <T> SavedStateDataSource.inCurrentProfileSession(
crossinline block: suspend (ProfileId?) -> T,
crossinline block: suspend SessionContext.Current.(ProfileId?) -> T,
): T? {
val state = savedState.first { it != InitialSavedState }
val currentProfileId = state.signedInProfileId
val profileData = currentProfileId?.let { state.profileData(it) }

return withContext(
SessionContext.Current(
tokens = profileData?.auth,
profileData = profileData ?: SavedState.ProfileData.defaultGuestData,
),
) {
val context = SessionContext.Current(
tokens = profileData?.auth,
profileData = profileData ?: SavedState.ProfileData.defaultGuestData,
)
return withContext(context) {
coroutineScope {
select {
async {
savedState.first { it.signedInProfileId != currentProfileId }
null
}.onAwait { it }
async {
block(currentProfileId)
block(context, currentProfileId)
}.onAwait { it }
}.also { coroutineContext.cancelChildren() }
}
Expand Down Expand Up @@ -545,20 +543,33 @@ internal inline fun <T> SavedStateDataSource.singleSessionFlow(
*/
internal suspend inline fun <T> SavedStateDataSource.inPastSession(
profileId: ProfileId,
crossinline block: suspend (SavedState.AuthTokens.Authenticated) -> T,
crossinline block: suspend SessionContext.Previous.() -> T,
): T? {
val state = savedState.first { it != InitialSavedState }

val profileData = state.profileData(profileId) ?: return null
val auth = profileData.auth as? SavedState.AuthTokens.Authenticated ?: return null
val context = SessionContext.Previous(
tokens = auth,
profileData = profileData,
)
return withContext(context) {
block(context)
}
}

return withContext(
SessionContext.Previous(
tokens = auth,
profileData = profileData,
),
) {
block(auth)
internal suspend inline fun <T> SavedStateDataSource.inProfileSession(
profileId: ProfileId,
crossinline block: suspend SessionContext.(ProfileId) -> T,
): T? {
val state = savedState.first { it != InitialSavedState }
val currentProfileId = state.signedInProfileId

return if (currentProfileId == profileId) inCurrentProfileSession { signedInProfileId ->
if (signedInProfileId == null) null else block(profileId)
}
else inPastSession(profileId) {
block(profileId)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,8 @@ class AndroidNotifier(
private val notificationManager = NotificationManagerCompat.from(context)

override suspend fun displayNotifications(notifications: List<Notification>) {
val currentLifecycleState = ProcessLifecycleOwner.get().lifecycle.currentStateFlow.value

// Show notifications in the background only
if (currentLifecycleState.isAtLeast(Lifecycle.State.RESUMED)) return
// TODO: When in app notifications are supported, check process lifecycle
// before displaying in the notifications tray. Till then, display in tray.

if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU ||
ContextCompat.checkSelfPermission(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ sealed class NotificationAction(
val senderDid = payload[NotificationAtProtoSenderDid]
?.let(::ProfileId)

val targetDid = payload[NotificationAtProtoTargetDid]
?.let(::ProfileId)

val recordUri: RecordUri? = payload[NotificationAtProtoRecordUri]
?.let { "${Uri.Host.AtProto.prefix}$it" }
?.asRecordUriOrNull()
Expand Down Expand Up @@ -164,7 +167,7 @@ private fun Flow<NotificationAction.RegisterToken>.registerTokenMutations(
// support updating a queued write to something else. For now, just write
// using the app scope and fix in a follow up PR.
val state = currentState()
if (state.hasNotificationPermissions && action.token != state.notificationToken) {
if (state.hasNotificationPermissions) {
val tokenRegistrationOutcome = notificationsRepository.registerPushNotificationToken(
action.token,
)
Expand All @@ -186,6 +189,7 @@ private fun Flow<NotificationAction.HandleNotification>.handleNotificationMutati
.flatMapMerge(NotificationProcessingMaxConcurrencyLimit) { action ->

val senderId = action.senderDid ?: return@flatMapMerge emptyFlow()
val targetDid = action.targetDid ?: return@flatMapMerge emptyFlow()
val recordUri = action.recordUri ?: return@flatMapMerge emptyFlow()
val reason = action.reason ?: return@flatMapMerge emptyFlow()

Expand All @@ -196,6 +200,7 @@ private fun Flow<NotificationAction.HandleNotification>.handleNotificationMutati
notificationsRepository.resolvePushNotification(
NotificationsQuery.Push(
senderId = senderId,
targetDid = targetDid,
recordUri = recordUri,
reason = reason,
),
Expand Down Expand Up @@ -250,6 +255,7 @@ private fun Flow<NotificationAction.RequestedNotificationPermission>.markNotific
}

private const val NotificationAtProtoSenderDid = "senderDid"
private const val NotificationAtProtoTargetDid = "targetDid"
private const val NotificationAtProtoRecordUri = "recordUri"
private const val NotificationAtProtoReason = "reason"
private const val NotificationProcessingMaxConcurrencyLimit = 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,4 @@ private fun TiledList<TimelineQuery, TimelineItem>.filterThreadDuplicates(): Til
.distinctBy(TimelineItem::id)
}

private val EMPTY_STATE_DELAY = 1.4.seconds
private val EMPTY_STATE_DELAY = 2.2.seconds
Loading