15
15
using libsignalservicedotnet . crypto ;
16
16
using System ;
17
17
using System . Collections . Generic ;
18
+ using System . Threading . Tasks ;
18
19
using System . Linq ;
19
20
20
21
namespace libsignalservice . crypto
@@ -72,77 +73,108 @@ public OutgoingPushMessage Encrypt(SignalProtocolAddress destination, Unidentifi
72
73
/// Decrypt a received <see cref="SignalServiceEnvelope"/>
73
74
/// </summary>
74
75
/// <param name="envelope">The received SignalServiceEnvelope</param>
76
+ /// <param name="callback">Optional callback to call during the decrypt process before it is acked</param>
75
77
/// <returns>a decrypted SignalServiceContent</returns>
76
- public SignalServiceContent ? Decrypt ( SignalServiceEnvelope envelope )
78
+ public async Task < SignalServiceContent ? > Decrypt ( SignalServiceEnvelope envelope , Func < SignalServiceContent ? , Task > callback = null )
77
79
{
80
+ Func < Plaintext , Task > callback_func = null ;
81
+ if ( callback != null )
82
+ {
83
+ callback_func = async ( data ) => await callback ( await DecryptComplete ( envelope , data ) ) ;
84
+ }
78
85
try
79
86
{
87
+ Plaintext plaintext = null ;
80
88
if ( envelope . HasLegacyMessage ( ) )
81
89
{
82
- Plaintext plaintext = Decrypt ( envelope , envelope . GetLegacyMessage ( ) ) ;
83
- DataMessage message = DataMessage . Parser . ParseFrom ( plaintext . Data ) ;
90
+ plaintext = await Decrypt ( envelope , envelope . GetLegacyMessage ( ) , callback_func ) ;
91
+ }
92
+ else if ( envelope . HasContent ( ) )
93
+ {
94
+ plaintext = await Decrypt ( envelope , envelope . GetContent ( ) , callback_func ) ;
95
+ }
96
+ if ( callback_func != null )
97
+ {
98
+ return null ;
99
+ }
100
+ return await DecryptComplete ( envelope , plaintext ) ;
101
+ }
102
+ catch ( InvalidProtocolBufferException e )
103
+ {
104
+ throw new InvalidMessageException ( e ) ;
105
+ }
106
+ }
107
+ private async Task < SignalServiceContent > DecryptComplete ( SignalServiceEnvelope envelope , Plaintext plaintext )
108
+ {
109
+ if ( envelope . HasLegacyMessage ( ) )
110
+ {
111
+ DataMessage message = DataMessage . Parser . ParseFrom ( plaintext . Data ) ;
112
+ return new SignalServiceContent ( plaintext . Metadata . Sender ,
113
+ plaintext . Metadata . SenderDevice ,
114
+ plaintext . Metadata . Timestamp ,
115
+ plaintext . Metadata . NeedsReceipt )
116
+ {
117
+ Message = CreateSignalServiceMessage ( plaintext . Metadata , message )
118
+ } ;
119
+ }
120
+ else if ( envelope . HasContent ( ) )
121
+ {
122
+ Content message = Content . Parser . ParseFrom ( plaintext . Data ) ;
123
+ if ( message . DataMessageOneofCase == Content . DataMessageOneofOneofCase . DataMessage )
124
+ {
84
125
return new SignalServiceContent ( plaintext . Metadata . Sender ,
85
- plaintext . Metadata . SenderDevice ,
86
- plaintext . Metadata . Timestamp ,
87
- plaintext . Metadata . NeedsReceipt )
126
+ plaintext . Metadata . SenderDevice ,
127
+ plaintext . Metadata . Timestamp ,
128
+ plaintext . Metadata . NeedsReceipt )
88
129
{
89
- Message = CreateSignalServiceMessage ( plaintext . Metadata , message )
130
+ Message = CreateSignalServiceMessage ( plaintext . Metadata , message . DataMessage )
90
131
} ;
91
132
}
92
- else if ( envelope . HasContent ( ) )
133
+ else if ( message . SyncMessageOneofCase == Content . SyncMessageOneofOneofCase . SyncMessage )
93
134
{
94
- Plaintext plaintext = Decrypt ( envelope , envelope . Envelope . Content . ToByteArray ( ) ) ;
95
- Content message = Content . Parser . ParseFrom ( plaintext . Data ) ;
96
- if ( message . DataMessageOneofCase == Content . DataMessageOneofOneofCase . DataMessage )
97
- {
98
- return new SignalServiceContent ( plaintext . Metadata . Sender ,
99
- plaintext . Metadata . SenderDevice ,
100
- plaintext . Metadata . Timestamp ,
101
- plaintext . Metadata . NeedsReceipt )
102
- {
103
- Message = CreateSignalServiceMessage ( plaintext . Metadata , message . DataMessage )
104
- } ;
105
- }
106
- else if ( message . SyncMessageOneofCase == Content . SyncMessageOneofOneofCase . SyncMessage )
135
+ return new SignalServiceContent ( plaintext . Metadata . Sender ,
136
+ plaintext . Metadata . SenderDevice ,
137
+ plaintext . Metadata . Timestamp ,
138
+ plaintext . Metadata . NeedsReceipt )
107
139
{
108
- return new SignalServiceContent ( plaintext . Metadata . Sender ,
109
- plaintext . Metadata . SenderDevice ,
110
- plaintext . Metadata . Timestamp ,
111
- plaintext . Metadata . NeedsReceipt )
112
- {
113
- SynchronizeMessage = CreateSynchronizeMessage ( plaintext . Metadata , message . SyncMessage )
114
- } ;
115
- }
116
- else if ( message . CallMessageOneofCase == Content . CallMessageOneofOneofCase . CallMessage )
140
+ SynchronizeMessage = CreateSynchronizeMessage ( plaintext . Metadata , message . SyncMessage )
141
+ } ;
142
+ }
143
+ else if ( message . CallMessageOneofCase == Content . CallMessageOneofOneofCase . CallMessage )
144
+ {
145
+ return new SignalServiceContent ( plaintext . Metadata . Sender ,
146
+ plaintext . Metadata . SenderDevice ,
147
+ plaintext . Metadata . Timestamp ,
148
+ plaintext . Metadata . NeedsReceipt )
117
149
{
118
- return new SignalServiceContent ( plaintext . Metadata . Sender ,
119
- plaintext . Metadata . SenderDevice ,
120
- plaintext . Metadata . Timestamp ,
121
- plaintext . Metadata . NeedsReceipt )
122
- {
123
- CallMessage = CreateCallMessage ( message . CallMessage )
124
- } ;
125
- }
126
- else if ( message . ReceiptMessageOneofCase == Content . ReceiptMessageOneofOneofCase . ReceiptMessage )
150
+ CallMessage = CreateCallMessage ( message . CallMessage )
151
+ } ;
152
+ }
153
+ else if ( message . ReceiptMessageOneofCase == Content . ReceiptMessageOneofOneofCase . ReceiptMessage )
154
+ {
155
+ return new SignalServiceContent ( plaintext . Metadata . Sender ,
156
+ plaintext . Metadata . SenderDevice ,
157
+ plaintext . Metadata . Timestamp ,
158
+ plaintext . Metadata . NeedsReceipt )
127
159
{
128
- return new SignalServiceContent ( plaintext . Metadata . Sender ,
129
- plaintext . Metadata . SenderDevice ,
130
- plaintext . Metadata . Timestamp ,
131
- plaintext . Metadata . NeedsReceipt )
132
- {
133
- ReadMessage = CreateReceiptMessage ( plaintext . Metadata , message . ReceiptMessage )
134
- } ;
135
- }
160
+ ReadMessage = CreateReceiptMessage ( plaintext . Metadata , message . ReceiptMessage )
161
+ } ;
136
162
}
137
- return null ;
138
163
}
139
- catch ( InvalidProtocolBufferException e )
164
+ return null ;
165
+ }
166
+ private class DecryptionCallbackHandler : DecryptionCallback
167
+ {
168
+ public Task handlePlaintext ( byte [ ] data , uint sessionVersion )
140
169
{
141
- throw new InvalidMetadataMessageException ( e ) ;
170
+ data = GetStrippedMessage ( sessionVersion , data ) ;
171
+ return callback ( new Plaintext ( metadata , data ) ) ;
142
172
}
173
+ public SessionCipher sessionCipher ;
174
+ public Metadata metadata ;
175
+ public Func < Plaintext , Task > callback ;
143
176
}
144
-
145
- private Plaintext Decrypt ( SignalServiceEnvelope envelope , byte [ ] ciphertext )
177
+ private async Task < Plaintext > Decrypt ( SignalServiceEnvelope envelope , byte [ ] ciphertext , Func < Plaintext , Task > callback = null )
146
178
{
147
179
try
148
180
{
@@ -153,15 +185,27 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
153
185
byte [ ] paddedMessage ;
154
186
Metadata metadata ;
155
187
uint sessionVersion ;
156
-
188
+ DecryptionCallbackHandler callback_handler = null ;
189
+ if ( callback != null )
190
+ callback_handler = new DecryptionCallbackHandler { callback = callback , sessionCipher = sessionCipher } ;
157
191
if ( envelope . IsPreKeySignalMessage ( ) )
158
192
{
159
- paddedMessage = sessionCipher . decrypt ( new PreKeySignalMessage ( ciphertext ) ) ;
160
193
metadata = new Metadata ( envelope . GetSource ( ) , envelope . GetSourceDevice ( ) , envelope . GetTimestamp ( ) , false ) ;
194
+ if ( callback_handler != null )
195
+ {
196
+ await sessionCipher . decrypt ( new PreKeySignalMessage ( ciphertext ) , callback_handler ) ;
197
+ return null ;
198
+ }
199
+ paddedMessage = sessionCipher . decrypt ( new PreKeySignalMessage ( ciphertext ) ) ;
161
200
sessionVersion = sessionCipher . getSessionVersion ( ) ;
162
201
}
163
202
else if ( envelope . IsSignalMessage ( ) )
164
203
{
204
+ if ( callback_handler != null )
205
+ {
206
+ await sessionCipher . decrypt ( new SignalMessage ( ciphertext ) , callback_handler ) ;
207
+ return null ;
208
+ }
165
209
paddedMessage = sessionCipher . decrypt ( new SignalMessage ( ciphertext ) ) ;
166
210
metadata = new Metadata ( envelope . GetSource ( ) , envelope . GetSourceDevice ( ) , envelope . GetTimestamp ( ) , false ) ;
167
211
sessionVersion = sessionCipher . getSessionVersion ( ) ;
@@ -170,16 +214,14 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
170
214
{
171
215
var results = sealedSessionCipher . Decrypt ( CertificateValidator , ciphertext , ( long ) envelope . Envelope . ServerTimestamp ) ;
172
216
paddedMessage = results . Item2 ;
173
- metadata = new Metadata ( results . Item1 . Name , ( int ) results . Item1 . DeviceId , ( long ) envelope . Envelope . Timestamp , true ) ;
174
- sessionVersion = ( uint ) sealedSessionCipher . GetSessionVersion ( new SignalProtocolAddress ( metadata . Sender , ( uint ) metadata . SenderDevice ) ) ;
217
+ metadata = new Metadata ( results . Item1 . Name , ( int ) results . Item1 . DeviceId , ( long ) envelope . Envelope . Timestamp , true ) ;
218
+ sessionVersion = ( uint ) sealedSessionCipher . GetSessionVersion ( new SignalProtocolAddress ( metadata . Sender , ( uint ) metadata . SenderDevice ) ) ;
175
219
}
176
220
else
177
221
{
178
222
throw new InvalidMessageException ( "Unknown type: " + envelope . GetEnvelopeType ( ) + " from " + envelope . GetSource ( ) ) ;
179
223
}
180
-
181
- PushTransportDetails transportDetails = new PushTransportDetails ( sessionVersion ) ;
182
- byte [ ] data = transportDetails . GetStrippedPaddingMessageBody ( paddedMessage ) ;
224
+ var data = GetStrippedMessage ( sessionVersion , paddedMessage ) ;
183
225
return new Plaintext ( metadata , data ) ;
184
226
}
185
227
catch ( DuplicateMessageException e )
@@ -214,7 +256,15 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
214
256
{
215
257
throw new ProtocolNoSessionException ( e , envelope . GetSource ( ) , envelope . GetSourceDevice ( ) ) ;
216
258
}
259
+
217
260
}
261
+ private static byte [ ] GetStrippedMessage ( uint sessionVersion , byte [ ] paddedMessage )
262
+ {
263
+ PushTransportDetails transportDetails = new PushTransportDetails ( sessionVersion ) ;
264
+ byte [ ] data = transportDetails . GetStrippedPaddingMessageBody ( paddedMessage ) ;
265
+ return data ;
266
+ }
267
+
218
268
219
269
private SignalServiceDataMessage CreateSignalServiceMessage ( Metadata metadata , DataMessage content )
220
270
{
0 commit comments