Skip to content

Commit bff067d

Browse files
committed
Added callback support and as async (requires change to libsignal-protocol-dotnet for it)
Add GetStrippedMessage that takes a SessionVersion instead and use it so callback support on new sessions works
1 parent b3f9c3c commit bff067d

File tree

1 file changed

+109
-59
lines changed

1 file changed

+109
-59
lines changed

libsignal-service-dotnet/crypto/SignalServiceCipher.cs

Lines changed: 109 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
using libsignalservicedotnet.crypto;
1616
using System;
1717
using System.Collections.Generic;
18+
using System.Threading.Tasks;
1819
using System.Linq;
1920

2021
namespace libsignalservice.crypto
@@ -72,77 +73,108 @@ public OutgoingPushMessage Encrypt(SignalProtocolAddress destination, Unidentifi
7273
/// Decrypt a received <see cref="SignalServiceEnvelope"/>
7374
/// </summary>
7475
/// <param name="envelope">The received SignalServiceEnvelope</param>
76+
/// <param name="callback">Optional callback to call during the decrypt process before it is acked</param>
7577
/// <returns>a decrypted SignalServiceContent</returns>
76-
public SignalServiceContent? Decrypt(SignalServiceEnvelope envelope)
78+
public async Task<SignalServiceContent?> Decrypt(SignalServiceEnvelope envelope, Func<SignalServiceContent?, Task> callback = null)
7779
{
80+
Func<Plaintext, Task> callback_func = null;
81+
if (callback != null)
82+
{
83+
callback_func = async (data) => await callback(await DecryptComplete(envelope, data));
84+
}
7885
try
7986
{
87+
Plaintext plaintext = null;
8088
if (envelope.HasLegacyMessage())
8189
{
82-
Plaintext plaintext = Decrypt(envelope, envelope.GetLegacyMessage());
83-
DataMessage message = DataMessage.Parser.ParseFrom(plaintext.Data);
90+
plaintext = await Decrypt(envelope, envelope.GetLegacyMessage(), callback_func);
91+
}
92+
else if (envelope.HasContent())
93+
{
94+
plaintext = await Decrypt(envelope, envelope.GetContent(), callback_func);
95+
}
96+
if (callback_func != null)
97+
{
98+
return null;
99+
}
100+
return await DecryptComplete(envelope, plaintext);
101+
}
102+
catch (InvalidProtocolBufferException e)
103+
{
104+
throw new InvalidMessageException(e);
105+
}
106+
}
107+
private async Task<SignalServiceContent> DecryptComplete(SignalServiceEnvelope envelope, Plaintext plaintext)
108+
{
109+
if (envelope.HasLegacyMessage())
110+
{
111+
DataMessage message = DataMessage.Parser.ParseFrom(plaintext.Data);
112+
return new SignalServiceContent(plaintext.Metadata.Sender,
113+
plaintext.Metadata.SenderDevice,
114+
plaintext.Metadata.Timestamp,
115+
plaintext.Metadata.NeedsReceipt)
116+
{
117+
Message = CreateSignalServiceMessage(plaintext.Metadata, message)
118+
};
119+
}
120+
else if (envelope.HasContent())
121+
{
122+
Content message = Content.Parser.ParseFrom(plaintext.Data);
123+
if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage)
124+
{
84125
return new SignalServiceContent(plaintext.Metadata.Sender,
85-
plaintext.Metadata.SenderDevice,
86-
plaintext.Metadata.Timestamp,
87-
plaintext.Metadata.NeedsReceipt)
126+
plaintext.Metadata.SenderDevice,
127+
plaintext.Metadata.Timestamp,
128+
plaintext.Metadata.NeedsReceipt)
88129
{
89-
Message = CreateSignalServiceMessage(plaintext.Metadata, message)
130+
Message = CreateSignalServiceMessage(plaintext.Metadata, message.DataMessage)
90131
};
91132
}
92-
else if (envelope.HasContent())
133+
else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage)
93134
{
94-
Plaintext plaintext = Decrypt(envelope, envelope.Envelope.Content.ToByteArray());
95-
Content message = Content.Parser.ParseFrom(plaintext.Data);
96-
if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage)
97-
{
98-
return new SignalServiceContent(plaintext.Metadata.Sender,
99-
plaintext.Metadata.SenderDevice,
100-
plaintext.Metadata.Timestamp,
101-
plaintext.Metadata.NeedsReceipt)
102-
{
103-
Message = CreateSignalServiceMessage(plaintext.Metadata, message.DataMessage)
104-
};
105-
}
106-
else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage)
135+
return new SignalServiceContent(plaintext.Metadata.Sender,
136+
plaintext.Metadata.SenderDevice,
137+
plaintext.Metadata.Timestamp,
138+
plaintext.Metadata.NeedsReceipt)
107139
{
108-
return new SignalServiceContent(plaintext.Metadata.Sender,
109-
plaintext.Metadata.SenderDevice,
110-
plaintext.Metadata.Timestamp,
111-
plaintext.Metadata.NeedsReceipt)
112-
{
113-
SynchronizeMessage = CreateSynchronizeMessage(plaintext.Metadata, message.SyncMessage)
114-
};
115-
}
116-
else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage)
140+
SynchronizeMessage = CreateSynchronizeMessage(plaintext.Metadata, message.SyncMessage)
141+
};
142+
}
143+
else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage)
144+
{
145+
return new SignalServiceContent(plaintext.Metadata.Sender,
146+
plaintext.Metadata.SenderDevice,
147+
plaintext.Metadata.Timestamp,
148+
plaintext.Metadata.NeedsReceipt)
117149
{
118-
return new SignalServiceContent(plaintext.Metadata.Sender,
119-
plaintext.Metadata.SenderDevice,
120-
plaintext.Metadata.Timestamp,
121-
plaintext.Metadata.NeedsReceipt)
122-
{
123-
CallMessage = CreateCallMessage(message.CallMessage)
124-
};
125-
}
126-
else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage)
150+
CallMessage = CreateCallMessage(message.CallMessage)
151+
};
152+
}
153+
else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage)
154+
{
155+
return new SignalServiceContent(plaintext.Metadata.Sender,
156+
plaintext.Metadata.SenderDevice,
157+
plaintext.Metadata.Timestamp,
158+
plaintext.Metadata.NeedsReceipt)
127159
{
128-
return new SignalServiceContent(plaintext.Metadata.Sender,
129-
plaintext.Metadata.SenderDevice,
130-
plaintext.Metadata.Timestamp,
131-
plaintext.Metadata.NeedsReceipt)
132-
{
133-
ReadMessage = CreateReceiptMessage(plaintext.Metadata, message.ReceiptMessage)
134-
};
135-
}
160+
ReadMessage = CreateReceiptMessage(plaintext.Metadata, message.ReceiptMessage)
161+
};
136162
}
137-
return null;
138163
}
139-
catch (InvalidProtocolBufferException e)
164+
return null;
165+
}
166+
private class DecryptionCallbackHandler : DecryptionCallback
167+
{
168+
public Task handlePlaintext(byte[] data, uint sessionVersion)
140169
{
141-
throw new InvalidMetadataMessageException(e);
170+
data = GetStrippedMessage(sessionVersion, data);
171+
return callback(new Plaintext(metadata, data));
142172
}
173+
public SessionCipher sessionCipher;
174+
public Metadata metadata;
175+
public Func<Plaintext, Task> callback;
143176
}
144-
145-
private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
177+
private async Task<Plaintext> Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext, Func<Plaintext, Task> callback = null)
146178
{
147179
try
148180
{
@@ -153,15 +185,27 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
153185
byte[] paddedMessage;
154186
Metadata metadata;
155187
uint sessionVersion;
156-
188+
DecryptionCallbackHandler callback_handler = null;
189+
if (callback != null)
190+
callback_handler = new DecryptionCallbackHandler { callback = callback, sessionCipher = sessionCipher };
157191
if (envelope.IsPreKeySignalMessage())
158192
{
159-
paddedMessage = sessionCipher.decrypt(new PreKeySignalMessage(ciphertext));
160193
metadata = new Metadata(envelope.GetSource(), envelope.GetSourceDevice(), envelope.GetTimestamp(), false);
194+
if (callback_handler != null)
195+
{
196+
await sessionCipher.decrypt(new PreKeySignalMessage(ciphertext), callback_handler);
197+
return null;
198+
}
199+
paddedMessage = sessionCipher.decrypt(new PreKeySignalMessage(ciphertext));
161200
sessionVersion = sessionCipher.getSessionVersion();
162201
}
163202
else if (envelope.IsSignalMessage())
164203
{
204+
if (callback_handler != null)
205+
{
206+
await sessionCipher.decrypt(new SignalMessage(ciphertext), callback_handler);
207+
return null;
208+
}
165209
paddedMessage = sessionCipher.decrypt(new SignalMessage(ciphertext));
166210
metadata = new Metadata(envelope.GetSource(), envelope.GetSourceDevice(), envelope.GetTimestamp(), false);
167211
sessionVersion = sessionCipher.getSessionVersion();
@@ -170,16 +214,14 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
170214
{
171215
var results = sealedSessionCipher.Decrypt(CertificateValidator, ciphertext, (long)envelope.Envelope.ServerTimestamp);
172216
paddedMessage = results.Item2;
173-
metadata = new Metadata(results.Item1.Name, (int) results.Item1.DeviceId, (long) envelope.Envelope.Timestamp, true);
174-
sessionVersion = (uint) sealedSessionCipher.GetSessionVersion(new SignalProtocolAddress(metadata.Sender, (uint) metadata.SenderDevice));
217+
metadata = new Metadata(results.Item1.Name, (int)results.Item1.DeviceId, (long)envelope.Envelope.Timestamp, true);
218+
sessionVersion = (uint)sealedSessionCipher.GetSessionVersion(new SignalProtocolAddress(metadata.Sender, (uint)metadata.SenderDevice));
175219
}
176220
else
177221
{
178222
throw new InvalidMessageException("Unknown type: " + envelope.GetEnvelopeType() + " from " + envelope.GetSource());
179223
}
180-
181-
PushTransportDetails transportDetails = new PushTransportDetails(sessionVersion);
182-
byte[] data = transportDetails.GetStrippedPaddingMessageBody(paddedMessage);
224+
var data = GetStrippedMessage(sessionVersion, paddedMessage);
183225
return new Plaintext(metadata, data);
184226
}
185227
catch (DuplicateMessageException e)
@@ -214,7 +256,15 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
214256
{
215257
throw new ProtocolNoSessionException(e, envelope.GetSource(), envelope.GetSourceDevice());
216258
}
259+
217260
}
261+
private static byte[] GetStrippedMessage(uint sessionVersion, byte[] paddedMessage)
262+
{
263+
PushTransportDetails transportDetails = new PushTransportDetails(sessionVersion);
264+
byte[] data = transportDetails.GetStrippedPaddingMessageBody(paddedMessage);
265+
return data;
266+
}
267+
218268

219269
private SignalServiceDataMessage CreateSignalServiceMessage(Metadata metadata, DataMessage content)
220270
{

0 commit comments

Comments
 (0)