Skip to content

Commit b2fb6e9

Browse files
feat: force ws enabled via mdm (WPB-23228) (#3855)
* feat: implement bulk update for persistent WebSocket status via MDM * fixing build issue --------- Co-authored-by: Mohamad Jaara <mohamad.jaara@wire.com>
1 parent 46745c7 commit b2fb6e9

File tree

6 files changed

+142
-7
lines changed

6 files changed

+142
-7
lines changed

data/persistence/src/commonMain/db_global/com/wire/kalium/persistence/Accounts.sq

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ SELECT isPersistentWebSocketEnabled FROM Accounts WHERE logout_reason IS NULL AN
5757
updatePersistentWebSocketStatus:
5858
UPDATE Accounts SET isPersistentWebSocketEnabled = :isPersistentWebSocketEnabled WHERE id = :userId;
5959

60+
updateAllPersistentWebSocketStatus:
61+
UPDATE Accounts SET isPersistentWebSocketEnabled = :enabled WHERE logout_reason IS NULL;
62+
6063
updateSsoId:
6164
UPDATE Accounts SET scim_external_id = :scimExternalId, subject = :subject, tenant = :tenant WHERE id = :userId;
6265

data/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/daokaliumdb/AccountsDAO.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ interface AccountsDAO {
174174
suspend fun deleteAccount(userIDEntity: UserIDEntity)
175175
suspend fun markAccountAsInvalid(userIDEntity: UserIDEntity, logoutReason: LogoutReason)
176176
suspend fun updatePersistentWebSocketStatus(userIDEntity: UserIDEntity, isPersistentWebSocketEnabled: Boolean)
177+
suspend fun setAllAccountsPersistentWebSocketEnabled(enabled: Boolean)
177178
suspend fun persistentWebSocketStatus(userIDEntity: UserIDEntity): Boolean
178179
suspend fun accountInfo(userIDEntity: UserIDEntity): AccountInfoEntity?
179180
fun fullAccountInfo(userIDEntity: UserIDEntity): FullAccountEntity?
@@ -304,6 +305,12 @@ internal class AccountsDAOImpl internal constructor(
304305
}
305306
}
306307

308+
override suspend fun setAllAccountsPersistentWebSocketEnabled(enabled: Boolean) {
309+
withContext(queriesContext) {
310+
queries.updateAllPersistentWebSocketStatus(enabled)
311+
}
312+
}
313+
307314
override suspend fun persistentWebSocketStatus(userIDEntity: UserIDEntity): Boolean = withContext(queriesContext) {
308315
queries.persistentWebSocketStatus(userIDEntity).executeAsOne()
309316
}

data/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/globalDB/AccountsDAOTest.kt

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,9 @@ import com.wire.kalium.persistence.db.GlobalDatabaseBuilder
2929
import com.wire.kalium.persistence.model.LogoutReason
3030
import com.wire.kalium.persistence.model.ServerConfigEntity
3131
import com.wire.kalium.persistence.model.SsoIdEntity
32-
import kotlinx.coroutines.Dispatchers
32+
import kotlinx.coroutines.flow.first
3333
import kotlinx.coroutines.ExperimentalCoroutinesApi
34-
import kotlinx.coroutines.test.StandardTestDispatcher
35-
import kotlinx.coroutines.test.TestCoroutineScheduler
36-
import kotlinx.coroutines.test.TestDispatcher
37-
import kotlinx.coroutines.test.resetMain
3834
import kotlinx.coroutines.test.runTest
39-
import kotlinx.coroutines.test.setMain
40-
import kotlin.test.AfterTest
4135
import kotlin.test.BeforeTest
4236
import kotlin.test.Test
4337
import kotlin.test.assertEquals
@@ -207,6 +201,79 @@ class AccountsDAOTest : GlobalDBBaseTest() {
207201
assertEquals(null, result)
208202
}
209203

204+
@Test
205+
fun whenUpdatingPersistentWebSocketStatus_thenStatusIsUpdated() = runTest {
206+
val account = VALID_ACCOUNT
207+
globalDatabaseBuilder.accountsDAO.insertOrReplace(account.info.userIDEntity, account.ssoId, account.managedBy, account.serverConfigId, false)
208+
209+
// initial status false
210+
val initial = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(account.info.userIDEntity)
211+
assertEquals(false, initial)
212+
213+
// update to true
214+
globalDatabaseBuilder.accountsDAO.updatePersistentWebSocketStatus(account.info.userIDEntity, true)
215+
val updated = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(account.info.userIDEntity)
216+
assertEquals(true, updated)
217+
}
218+
219+
@Test
220+
fun whenSettingAllAccountsPersistentWebSocketEnabled_thenAllStatusesAreUpdated() = runTest {
221+
val a1 = VALID_ACCOUNT
222+
val a2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("user2", "domain2"), null))
223+
val a3 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("user3", "domain3"), null))
224+
225+
listOf(a1, a2, a3).forEach {
226+
globalDatabaseBuilder.accountsDAO.insertOrReplace(it.info.userIDEntity, it.ssoId, it.managedBy, it.serverConfigId, false)
227+
}
228+
229+
globalDatabaseBuilder.accountsDAO.setAllAccountsPersistentWebSocketEnabled(true)
230+
231+
listOf(a1, a2, a3).forEach {
232+
val status = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(it.info.userIDEntity)
233+
assertEquals(true, status)
234+
}
235+
}
236+
237+
@Test
238+
fun whenGettingAllValidAccountPersistentWebSocketStatus_thenOnlyValidAccountsIncluded() = runTest {
239+
val valid1 = VALID_ACCOUNT
240+
val valid2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("userB", "domainB"), null))
241+
val invalid = INVALID_ACCOUNT
242+
243+
// insert accounts with different initial statuses
244+
globalDatabaseBuilder.accountsDAO.insertOrReplace(valid1.info.userIDEntity, valid1.ssoId, valid1.managedBy, valid1.serverConfigId, true)
245+
globalDatabaseBuilder.accountsDAO.insertOrReplace(valid2.info.userIDEntity, valid2.ssoId, valid2.managedBy, valid2.serverConfigId, false)
246+
globalDatabaseBuilder.accountsDAO.insertOrReplace(invalid.info.userIDEntity, invalid.ssoId, invalid.managedBy, invalid.serverConfigId, true)
247+
globalDatabaseBuilder.accountsDAO.markAccountAsInvalid(invalid.info.userIDEntity, invalid.info.logoutReason!!)
248+
249+
val list = globalDatabaseBuilder.accountsDAO.getAllValidAccountPersistentWebSocketStatus().first()
250+
// Should contain only the two valid accounts in any order
251+
val ids = list.map { it.userIDEntity }.toSet()
252+
assertEquals(setOf(valid1.info.userIDEntity, valid2.info.userIDEntity), ids)
253+
val map = list.associateBy({ it.userIDEntity }, { it.isPersistentWebSocketEnabled })
254+
assertEquals(true, map[valid1.info.userIDEntity])
255+
assertEquals(false, map[valid2.info.userIDEntity])
256+
}
257+
258+
@Test
259+
fun whenRequestingValidAccountWithServerConfigId_thenReturnMapForValidAccounts() = runTest {
260+
val valid1 = VALID_ACCOUNT
261+
val valid2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("userC", "domainC"), null))
262+
val invalid = INVALID_ACCOUNT
263+
264+
listOf(valid1, valid2, invalid).forEach {
265+
globalDatabaseBuilder.accountsDAO.insertOrReplace(it.info.userIDEntity, it.ssoId, it.managedBy, it.serverConfigId, false)
266+
}
267+
globalDatabaseBuilder.accountsDAO.markAccountAsInvalid(invalid.info.userIDEntity, invalid.info.logoutReason!!)
268+
269+
val map = globalDatabaseBuilder.accountsDAO.validAccountWithServerConfigId()
270+
// only valid1 and valid2 should be present
271+
assertEquals(setOf(valid1.info.userIDEntity, valid2.info.userIDEntity), map.keys)
272+
map.values.forEach { serverConfig ->
273+
assertEquals(SERVER_CONFIG, serverConfig)
274+
}
275+
}
276+
210277
private companion object {
211278

212279
val VALID_ACCOUNT = FullAccountEntity(

logic/src/commonMain/kotlin/com/wire/kalium/logic/GlobalKaliumScope.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ import com.wire.kalium.logic.feature.user.ObserveValidAccountsUseCase
6363
import com.wire.kalium.logic.feature.user.ObserveValidAccountsUseCaseImpl
6464
import com.wire.kalium.logic.feature.user.webSocketStatus.ObservePersistentWebSocketConnectionStatusUseCase
6565
import com.wire.kalium.logic.feature.user.webSocketStatus.ObservePersistentWebSocketConnectionStatusUseCaseImpl
66+
import com.wire.kalium.logic.feature.user.webSocketStatus.SetPersistentWebSocketForAllUsersUseCase
67+
import com.wire.kalium.logic.feature.user.webSocketStatus.SetPersistentWebSocketForAllUsersUseCaseImpl
6668
import com.wire.kalium.logic.featureFlags.KaliumConfigs
6769
import com.wire.kalium.logic.sync.GlobalWorkScheduler
6870
import com.wire.kalium.logic.sync.WorkSchedulerProvider
@@ -121,6 +123,9 @@ public class GlobalKaliumScope internal constructor(
121123
public val observePersistentWebSocketConnectionStatus: ObservePersistentWebSocketConnectionStatusUseCase
122124
get() = ObservePersistentWebSocketConnectionStatusUseCaseImpl(sessionRepository)
123125

126+
public val setAllPersistentWebSocketEnabled: SetPersistentWebSocketForAllUsersUseCase
127+
get() = SetPersistentWebSocketForAllUsersUseCaseImpl(sessionRepository)
128+
124129
private val notificationTokenRepository: NotificationTokenRepository
125130
get() = NotificationTokenDataSource(globalPreferences.tokenStorage)
126131

logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ internal interface SessionRepository {
7777
suspend fun deleteSession(userId: UserId): Either<StorageFailure, Unit>
7878
suspend fun ssoId(userId: UserId): Either<StorageFailure, SsoIdEntity?>
7979
suspend fun updatePersistentWebSocketStatus(userId: UserId, isPersistentWebSocketEnabled: Boolean): Either<StorageFailure, Unit>
80+
suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either<StorageFailure, Unit>
8081
suspend fun updateSsoIdAndScimInfo(userId: UserId, ssoId: SsoId?, managedBy: ManagedByDTO?): Either<StorageFailure, Unit>
8182
suspend fun isFederated(userId: UserId): Either<StorageFailure, Boolean>
8283
suspend fun getAllValidAccountPersistentWebSocketStatus(): Either<StorageFailure, Flow<List<PersistentWebSocketStatus>>>
@@ -198,6 +199,9 @@ internal class SessionDataSource internal constructor(
198199
accountsDAO.updatePersistentWebSocketStatus(userId.toDao(), isPersistentWebSocketEnabled)
199200
}
200201

202+
override suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either<StorageFailure, Unit> =
203+
wrapStorageRequest { accountsDAO.setAllAccountsPersistentWebSocketEnabled(enabled) }
204+
201205
override suspend fun updateSsoIdAndScimInfo(
202206
userId: UserId,
203207
ssoId: SsoId?,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Wire
3+
* Copyright (C) 2025 Wire Swiss GmbH
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU General Public License as published by
7+
* the Free Software Foundation, either version 3 of the License, or
8+
* (at your option) any later version.
9+
*
10+
* This program is distributed in the hope that it will be useful,
11+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
* GNU General Public License for more details.
14+
*
15+
* You should have received a copy of the GNU General Public License
16+
* along with this program. If not, see http://www.gnu.org/licenses/.
17+
*/
18+
19+
package com.wire.kalium.logic.feature.user.webSocketStatus
20+
21+
import com.wire.kalium.common.error.CoreFailure
22+
import com.wire.kalium.common.functional.fold
23+
import com.wire.kalium.logic.data.session.SessionRepository
24+
25+
/**
26+
* This use case is responsible for setting the persistent web socket connection status for all users.
27+
*/
28+
public interface SetPersistentWebSocketForAllUsersUseCase {
29+
/**
30+
* @param enabled true if the persistent web socket connection should be enabled for all users, false otherwise
31+
*/
32+
public suspend operator fun invoke(enabled: Boolean): SetAllPersistentWebSocketEnabledResult
33+
}
34+
35+
public sealed class SetAllPersistentWebSocketEnabledResult {
36+
public data object Success : SetAllPersistentWebSocketEnabledResult()
37+
public data class Failure(val failure: CoreFailure) : SetAllPersistentWebSocketEnabledResult()
38+
}
39+
40+
internal class SetPersistentWebSocketForAllUsersUseCaseImpl(
41+
private val sessionRepository: SessionRepository
42+
) : SetPersistentWebSocketForAllUsersUseCase {
43+
override suspend operator fun invoke(enabled: Boolean): SetAllPersistentWebSocketEnabledResult =
44+
sessionRepository.setAllPersistentWebSocketEnabled(enabled).fold({
45+
SetAllPersistentWebSocketEnabledResult.Failure(it)
46+
}, {
47+
SetAllPersistentWebSocketEnabledResult.Success
48+
})
49+
}

0 commit comments

Comments
 (0)