Skip to content

Commit 20c958d

Browse files
johnxnguyennetbe
andauthored
fix: repair faulty removal keys - WPB-22447 🍒 (#4042)
Co-authored-by: François Benaiteau <[email protected]>
1 parent 2fef253 commit 20c958d

File tree

28 files changed

+721
-25
lines changed

28 files changed

+721
-25
lines changed

WireDomain/Sources/WireDomain/Components/ClientSessionComponent.swift

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ public final class ClientSessionComponent {
6565
private let coreCryptoProvider: any CoreCryptoProviderProtocol
6666
private let completionHandlers: CompletionHandlers
6767

68+
private let faultyMLSRemovalKeysByDomain: [String: [String]]
69+
6870
public init(
6971
selfUserID: UUID,
7072
selfClientID: String,
@@ -81,7 +83,8 @@ public final class ClientSessionComponent {
8183
mlsDecryptionService: any MLSDecryptionServiceInterface,
8284
proteusService: any ProteusServiceInterface,
8385
coreCryptoProvider: any CoreCryptoProviderProtocol,
84-
completionHandlers: CompletionHandlers
86+
completionHandlers: CompletionHandlers,
87+
faultyMLSRemovalKeysByDomain: [String: [String]]
8588
) {
8689
self.selfUserID = selfUserID
8790
self.selfClientID = selfClientID
@@ -99,6 +102,7 @@ public final class ClientSessionComponent {
99102
self.isMLSEnabled = isMLSEnabled
100103
self.coreCryptoProvider = coreCryptoProvider
101104
self.completionHandlers = completionHandlers
105+
self.faultyMLSRemovalKeysByDomain = faultyMLSRemovalKeysByDomain
102106
}
103107

104108
public private(set) lazy var authenticationManager = AuthenticationManager(
@@ -793,6 +797,15 @@ public final class ClientSessionComponent {
793797
userID: selfUserID
794798
)
795799

800+
public lazy var repairFaultyRemovalKeysUsecase = RepairRemovalKeysUseCase(
801+
faultyMLSRemovalKeysByDomain: faultyMLSRemovalKeysByDomain,
802+
context: syncContext,
803+
mlsService: mlsService,
804+
conversationsAPI: conversationsAPI,
805+
conversationLocalStore: conversationLocalStore,
806+
initiateResetUseCase: initiateResetMLSConversationUseCase
807+
)
808+
796809
public lazy var initiateResetMLSConversationUseCase = InitiateResetMLSConversationUseCase(
797810
api: mlsAPI,
798811
mlsService: mlsService,

WireDomain/Sources/WireDomain/Components/UserSessionComponent.swift

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ public final class UserSessionComponent {
4343
private let proteusService: any ProteusServiceInterface
4444
private let coreCryptoProvider: any CoreCryptoProviderProtocol
4545

46+
private let faultyMLSRemovalKeysByDomain: [String: [String]]
47+
4648
public init(
4749
currentBuildNumber: String,
4850
selfUserID: UUID,
@@ -59,7 +61,8 @@ public final class UserSessionComponent {
5961
mlsService: any MLSServiceInterface,
6062
mlsDecryptionService: any MLSDecryptionServiceInterface,
6163
proteusService: any ProteusServiceInterface,
62-
coreCryptoProvider: any CoreCryptoProviderProtocol
64+
coreCryptoProvider: any CoreCryptoProviderProtocol,
65+
faultyMLSRemovalKeysByDomain: [String: [String]]
6366
) {
6467
self.currentBuildNumber = currentBuildNumber
6568
self.selfUserID = selfUserID
@@ -77,6 +80,7 @@ public final class UserSessionComponent {
7780
self.proteusService = proteusService
7881
self.coreCryptoProvider = coreCryptoProvider
7982
self.sharedContainerURL = sharedContainerURL
83+
self.faultyMLSRemovalKeysByDomain = faultyMLSRemovalKeysByDomain
8084
}
8185

8286
private let cookieStorage: any CookieStorageProtocol
@@ -103,7 +107,8 @@ public final class UserSessionComponent {
103107
mlsDecryptionService: mlsDecryptionService,
104108
proteusService: proteusService,
105109
coreCryptoProvider: coreCryptoProvider,
106-
completionHandlers: completionHandlers
110+
completionHandlers: completionHandlers,
111+
faultyMLSRemovalKeysByDomain: faultyMLSRemovalKeysByDomain
107112
)
108113
}
109114

WireDomain/Sources/WireDomain/Helpers/InitiateResetMLSConversationUseCase.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import WireDataModel
2121
import WireLogging
2222
import WireNetwork
2323

24+
// sourcery: AutoMockable
2425
public protocol InitiateResetMLSConversationUseCaseProtocol {
2526
func invoke(groupID: WireDataModel.MLSGroupID, epoch: UInt64) async
2627
}

WireDomain/Sources/WireDomain/Repositories/Conversations/LocalStore/ConversationLocalStore.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,16 @@ public final class ConversationLocalStore: ConversationLocalStoreProtocol {
201201
}
202202
}
203203

204+
public func fetchAllMLSConversations(domain: String?) async throws -> [ZMConversation] {
205+
try await context.perform { [context] in
206+
try ZMConversation.fetchConversationsWithMLSGroupStatus(
207+
mlsGroupStatus: .ready,
208+
domain: domain,
209+
in: context
210+
)
211+
}
212+
}
213+
204214
public func fetchConversation(
205215
id: UUID,
206216
domain: String?

WireDomain/Sources/WireDomain/Repositories/Conversations/Protocols/ConversationLocalStoreProtocol.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,21 @@ public protocol ConversationLocalStoreProtocol {
8484
mlsGroupID: MLSGroupID
8585
) async
8686

87+
/// Fetches all MLS conversations that are ready.
88+
///
89+
/// This method retrieves all conversations that have MLS group IDs and are in a ready state,
90+
/// optionally filtered by domain.
91+
///
92+
/// - Parameter domain: The domain to filter conversations by. If `nil`, fetches conversations
93+
/// from all domains.
94+
///
95+
/// - Returns: An array of `ZMConversation` objects that are MLS-ready. Returns an empty array
96+
/// if no conversations match the criteria.
97+
///
98+
/// - Throws: An error if the fetch operation fails.
99+
100+
func fetchAllMLSConversations(domain: String?) async throws -> [ZMConversation]
101+
87102
/// Fetches a MLS conversation locally.
88103
///
89104
/// - parameters:
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
//
2+
// Wire
3+
// Copyright (C) 2025 Wire Swiss GmbH
4+
//
5+
// This program is free software: you can redistribute it and/or modify
6+
// it under the terms of the GNU General Public License as published by
7+
// the Free Software Foundation, either version 3 of the License, or
8+
// (at your option) any later version.
9+
//
10+
// This program is distributed in the hope that it will be useful,
11+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
// GNU General Public License for more details.
14+
//
15+
// You should have received a copy of the GNU General Public License
16+
// along with this program. If not, see http://www.gnu.org/licenses/.
17+
//
18+
19+
import WireDataModel
20+
import WireLogging
21+
import WireNetwork
22+
23+
// sourcery: AutoMockable
24+
/// Repairs conversations with faulty removal keys
25+
public protocol RepairRemovalKeysUseCaseProtocol {
26+
func invoke() async throws
27+
}
28+
29+
public struct RepairRemovalKeysUseCase: RepairRemovalKeysUseCaseProtocol {
30+
31+
let faultyMLSRemovalKeysByDomain: [String: [String]]
32+
33+
private let context: NSManagedObjectContext
34+
private let mlsService: MLSServiceInterface
35+
private let conversationsAPI: ConversationsAPI
36+
private let conversationLocalStore: ConversationLocalStoreProtocol
37+
private let initiateResetUseCase: InitiateResetMLSConversationUseCaseProtocol
38+
39+
init(
40+
faultyMLSRemovalKeysByDomain: [String: [String]],
41+
context: NSManagedObjectContext,
42+
mlsService: MLSServiceInterface,
43+
conversationsAPI: ConversationsAPI,
44+
conversationLocalStore: ConversationLocalStoreProtocol,
45+
initiateResetUseCase: InitiateResetMLSConversationUseCaseProtocol
46+
) {
47+
self.faultyMLSRemovalKeysByDomain = faultyMLSRemovalKeysByDomain
48+
self.context = context
49+
self.mlsService = mlsService
50+
self.conversationsAPI = conversationsAPI
51+
self.conversationLocalStore = conversationLocalStore
52+
self.initiateResetUseCase = initiateResetUseCase
53+
}
54+
55+
public func invoke() async throws {
56+
WireLogger.mls.info(
57+
"initiating repair of faulty removal keys",
58+
attributes: .safePublic
59+
)
60+
61+
guard !faultyMLSRemovalKeysByDomain.isEmpty else {
62+
WireLogger.mls.info(
63+
"no faulty removal keys to repair, aborting",
64+
attributes: .safePublic
65+
)
66+
return
67+
}
68+
69+
// Process each domain
70+
for (domain, faultyKeyHexStrings) in faultyMLSRemovalKeysByDomain {
71+
try await processDomain(
72+
domain: domain,
73+
faultyKeyHexStrings: faultyKeyHexStrings
74+
)
75+
}
76+
}
77+
78+
// MARK: - Private
79+
80+
private func processDomain(
81+
domain: String,
82+
faultyKeyHexStrings: [String]
83+
) async throws {
84+
WireLogger.mls.info(
85+
"checking domain for \(faultyKeyHexStrings.count) faulty key(s)",
86+
attributes: .safePublic
87+
)
88+
89+
// Convert hex strings to Data
90+
let faultyKeyDataList = faultyKeyHexStrings.compactMap(Data.init(hexString:))
91+
guard faultyKeyDataList.count == faultyKeyHexStrings.count else {
92+
WireLogger.mls.error(
93+
"failed to decode some faulty removal key hex strings",
94+
attributes: .safePublic
95+
)
96+
return
97+
}
98+
99+
let allMLSConversations = try await conversationLocalStore.fetchAllMLSConversations(
100+
domain: domain
101+
)
102+
103+
// Find faulty conversations for this domain
104+
let faultyConversations = await findFaultyConversations(
105+
in: allMLSConversations,
106+
faultyKeys: faultyKeyDataList
107+
)
108+
109+
WireLogger.mls.info(
110+
"detected \(faultyConversations.count)/\(allMLSConversations.count) affected conversations",
111+
attributes: .safePublic
112+
)
113+
114+
// Repair each faulty conversation in parallel
115+
await withTaskGroup(of: Void.self) { group in
116+
for (groupID, qualifiedID) in faultyConversations {
117+
group.addTask {
118+
await repairConversation(
119+
groupID: groupID,
120+
qualifiedID: qualifiedID
121+
)
122+
}
123+
}
124+
}
125+
}
126+
127+
private func findFaultyConversations(
128+
in conversations: [ZMConversation],
129+
faultyKeys: [Data]
130+
) async -> [(MLSGroupID, WireDataModel.QualifiedID)] {
131+
var faultyConversations: [(MLSGroupID, WireDataModel.QualifiedID)] = []
132+
133+
for conversation in conversations {
134+
let (groupID, qualifiedID) = await context.perform {
135+
(conversation.mlsGroupID, conversation.qualifiedID)
136+
}
137+
138+
guard let groupID, let qualifiedID else {
139+
continue
140+
}
141+
142+
let currentRemovalKey: Data
143+
do {
144+
currentRemovalKey = try await mlsService.externalSenderKey(groupID: groupID)
145+
} catch {
146+
WireLogger.mls.error(
147+
"failed to get current removal key for a group, skipping: \(String(describing: error))",
148+
attributes: .safePublic
149+
)
150+
continue
151+
}
152+
153+
// Check if the current removal key matches any of the faulty keys
154+
if faultyKeys.contains(currentRemovalKey) {
155+
faultyConversations.append((
156+
groupID,
157+
qualifiedID
158+
))
159+
}
160+
}
161+
162+
return faultyConversations
163+
}
164+
165+
private func repairConversation(
166+
groupID: MLSGroupID,
167+
qualifiedID: WireDataModel.QualifiedID
168+
) async {
169+
let remoteConversation: WireNetwork.Conversation?
170+
do {
171+
remoteConversation = try await conversationsAPI.getConversations(
172+
for: [qualifiedID.toAPIModel()]
173+
).found.first
174+
} catch {
175+
WireLogger.mls.error(
176+
"failed to get epoch for a group, skipping: \(String(describing: error))",
177+
attributes: .safePublic, [.conversationId: qualifiedID.safeForLoggingDescription]
178+
)
179+
return
180+
}
181+
182+
guard let remoteConversation else {
183+
WireLogger.mls.error(
184+
"remote conversation for a group not found, skipping",
185+
attributes: .safePublic, [.conversationId: qualifiedID.safeForLoggingDescription]
186+
)
187+
return
188+
}
189+
190+
WireLogger.mls.info(
191+
"initiating reset for faulty conversation: \(qualifiedID)",
192+
attributes: .safePublic, [.conversationId: qualifiedID.safeForLoggingDescription]
193+
)
194+
195+
let epoch = UInt64(remoteConversation.epoch ?? 0)
196+
await initiateResetUseCase.invoke(groupID: groupID, epoch: epoch)
197+
}
198+
199+
}

0 commit comments

Comments
 (0)