diff --git a/libsignal-service-dotnet/crypto/SignalServiceCipher.cs b/libsignal-service-dotnet/crypto/SignalServiceCipher.cs
index 58bd30b..050091e 100644
--- a/libsignal-service-dotnet/crypto/SignalServiceCipher.cs
+++ b/libsignal-service-dotnet/crypto/SignalServiceCipher.cs
@@ -15,6 +15,7 @@
using libsignalservicedotnet.crypto;
using System;
using System.Collections.Generic;
+using System.Threading.Tasks;
using System.Linq;
namespace libsignalservice.crypto
@@ -72,77 +73,108 @@ public OutgoingPushMessage Encrypt(SignalProtocolAddress destination, Unidentifi
/// Decrypt a received
///
/// The received SignalServiceEnvelope
+ /// Optional callback to call during the decrypt process before it is acked
/// a decrypted SignalServiceContent
- public SignalServiceContent? Decrypt(SignalServiceEnvelope envelope)
+ public async Task Decrypt(SignalServiceEnvelope envelope, Func callback = null)
{
+ Func callback_func = null;
+ if (callback != null)
+ {
+ callback_func = async (data) => await callback(await DecryptComplete(envelope, data));
+ }
try
{
+ Plaintext plaintext = null;
if (envelope.HasLegacyMessage())
{
- Plaintext plaintext = Decrypt(envelope, envelope.GetLegacyMessage());
- DataMessage message = DataMessage.Parser.ParseFrom(plaintext.Data);
+ plaintext = await Decrypt(envelope, envelope.GetLegacyMessage(), callback_func);
+ }
+ else if (envelope.HasContent())
+ {
+ plaintext = await Decrypt(envelope, envelope.GetContent(), callback_func);
+ }
+ if (callback_func != null)
+ {
+ return null;
+ }
+ return await DecryptComplete(envelope, plaintext);
+ }
+ catch (InvalidProtocolBufferException e)
+ {
+ throw new InvalidMessageException(e);
+ }
+ }
+ private async Task DecryptComplete(SignalServiceEnvelope envelope, Plaintext plaintext)
+ {
+ if (envelope.HasLegacyMessage())
+ {
+ DataMessage message = DataMessage.Parser.ParseFrom(plaintext.Data);
+ return new SignalServiceContent(plaintext.Metadata.Sender,
+ plaintext.Metadata.SenderDevice,
+ plaintext.Metadata.Timestamp,
+ plaintext.Metadata.NeedsReceipt)
+ {
+ Message = CreateSignalServiceMessage(plaintext.Metadata, message)
+ };
+ }
+ else if (envelope.HasContent())
+ {
+ Content message = Content.Parser.ParseFrom(plaintext.Data);
+ if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage)
+ {
return new SignalServiceContent(plaintext.Metadata.Sender,
- plaintext.Metadata.SenderDevice,
- plaintext.Metadata.Timestamp,
- plaintext.Metadata.NeedsReceipt)
+ plaintext.Metadata.SenderDevice,
+ plaintext.Metadata.Timestamp,
+ plaintext.Metadata.NeedsReceipt)
{
- Message = CreateSignalServiceMessage(plaintext.Metadata, message)
+ Message = CreateSignalServiceMessage(plaintext.Metadata, message.DataMessage)
};
}
- else if (envelope.HasContent())
+ else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage)
{
- Plaintext plaintext = Decrypt(envelope, envelope.Envelope.Content.ToByteArray());
- Content message = Content.Parser.ParseFrom(plaintext.Data);
- if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage)
- {
- return new SignalServiceContent(plaintext.Metadata.Sender,
- plaintext.Metadata.SenderDevice,
- plaintext.Metadata.Timestamp,
- plaintext.Metadata.NeedsReceipt)
- {
- Message = CreateSignalServiceMessage(plaintext.Metadata, message.DataMessage)
- };
- }
- else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage)
+ return new SignalServiceContent(plaintext.Metadata.Sender,
+ plaintext.Metadata.SenderDevice,
+ plaintext.Metadata.Timestamp,
+ plaintext.Metadata.NeedsReceipt)
{
- return new SignalServiceContent(plaintext.Metadata.Sender,
- plaintext.Metadata.SenderDevice,
- plaintext.Metadata.Timestamp,
- plaintext.Metadata.NeedsReceipt)
- {
- SynchronizeMessage = CreateSynchronizeMessage(plaintext.Metadata, message.SyncMessage)
- };
- }
- else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage)
+ SynchronizeMessage = CreateSynchronizeMessage(plaintext.Metadata, message.SyncMessage)
+ };
+ }
+ else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage)
+ {
+ return new SignalServiceContent(plaintext.Metadata.Sender,
+ plaintext.Metadata.SenderDevice,
+ plaintext.Metadata.Timestamp,
+ plaintext.Metadata.NeedsReceipt)
{
- return new SignalServiceContent(plaintext.Metadata.Sender,
- plaintext.Metadata.SenderDevice,
- plaintext.Metadata.Timestamp,
- plaintext.Metadata.NeedsReceipt)
- {
- CallMessage = CreateCallMessage(message.CallMessage)
- };
- }
- else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage)
+ CallMessage = CreateCallMessage(message.CallMessage)
+ };
+ }
+ else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage)
+ {
+ return new SignalServiceContent(plaintext.Metadata.Sender,
+ plaintext.Metadata.SenderDevice,
+ plaintext.Metadata.Timestamp,
+ plaintext.Metadata.NeedsReceipt)
{
- return new SignalServiceContent(plaintext.Metadata.Sender,
- plaintext.Metadata.SenderDevice,
- plaintext.Metadata.Timestamp,
- plaintext.Metadata.NeedsReceipt)
- {
- ReadMessage = CreateReceiptMessage(plaintext.Metadata, message.ReceiptMessage)
- };
- }
+ ReadMessage = CreateReceiptMessage(plaintext.Metadata, message.ReceiptMessage)
+ };
}
- return null;
}
- catch (InvalidProtocolBufferException e)
+ return null;
+ }
+ private class DecryptionCallbackHandler : DecryptionCallback
+ {
+ public Task handlePlaintext(byte[] data, uint sessionVersion)
{
- throw new InvalidMetadataMessageException(e);
+ data = GetStrippedMessage(sessionVersion, data);
+ return callback(new Plaintext(metadata, data));
}
+ public SessionCipher sessionCipher;
+ public Metadata metadata;
+ public Func callback;
}
-
- private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
+ private async Task Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext, Func callback = null)
{
try
{
@@ -153,15 +185,27 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
byte[] paddedMessage;
Metadata metadata;
uint sessionVersion;
-
+ DecryptionCallbackHandler callback_handler = null;
+ if (callback != null)
+ callback_handler = new DecryptionCallbackHandler { callback = callback, sessionCipher = sessionCipher };
if (envelope.IsPreKeySignalMessage())
{
- paddedMessage = sessionCipher.decrypt(new PreKeySignalMessage(ciphertext));
metadata = new Metadata(envelope.GetSource(), envelope.GetSourceDevice(), envelope.GetTimestamp(), false);
+ if (callback_handler != null)
+ {
+ await sessionCipher.decrypt(new PreKeySignalMessage(ciphertext), callback_handler);
+ return null;
+ }
+ paddedMessage = sessionCipher.decrypt(new PreKeySignalMessage(ciphertext));
sessionVersion = sessionCipher.getSessionVersion();
}
else if (envelope.IsSignalMessage())
{
+ if (callback_handler != null)
+ {
+ await sessionCipher.decrypt(new SignalMessage(ciphertext), callback_handler);
+ return null;
+ }
paddedMessage = sessionCipher.decrypt(new SignalMessage(ciphertext));
metadata = new Metadata(envelope.GetSource(), envelope.GetSourceDevice(), envelope.GetTimestamp(), false);
sessionVersion = sessionCipher.getSessionVersion();
@@ -170,16 +214,14 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
{
var results = sealedSessionCipher.Decrypt(CertificateValidator, ciphertext, (long)envelope.Envelope.ServerTimestamp);
paddedMessage = results.Item2;
- metadata = new Metadata(results.Item1.Name, (int) results.Item1.DeviceId, (long) envelope.Envelope.Timestamp, true);
- sessionVersion = (uint) sealedSessionCipher.GetSessionVersion(new SignalProtocolAddress(metadata.Sender, (uint) metadata.SenderDevice));
+ metadata = new Metadata(results.Item1.Name, (int)results.Item1.DeviceId, (long)envelope.Envelope.Timestamp, true);
+ sessionVersion = (uint)sealedSessionCipher.GetSessionVersion(new SignalProtocolAddress(metadata.Sender, (uint)metadata.SenderDevice));
}
else
{
throw new InvalidMessageException("Unknown type: " + envelope.GetEnvelopeType() + " from " + envelope.GetSource());
}
-
- PushTransportDetails transportDetails = new PushTransportDetails(sessionVersion);
- byte[] data = transportDetails.GetStrippedPaddingMessageBody(paddedMessage);
+ var data = GetStrippedMessage(sessionVersion, paddedMessage);
return new Plaintext(metadata, data);
}
catch (DuplicateMessageException e)
@@ -214,7 +256,15 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
{
throw new ProtocolNoSessionException(e, envelope.GetSource(), envelope.GetSourceDevice());
}
+
}
+ private static byte[] GetStrippedMessage(uint sessionVersion, byte[] paddedMessage)
+ {
+ PushTransportDetails transportDetails = new PushTransportDetails(sessionVersion);
+ byte[] data = transportDetails.GetStrippedPaddingMessageBody(paddedMessage);
+ return data;
+ }
+
private SignalServiceDataMessage CreateSignalServiceMessage(Metadata metadata, DataMessage content)
{