diff --git a/libsignal-service-dotnet/crypto/SignalServiceCipher.cs b/libsignal-service-dotnet/crypto/SignalServiceCipher.cs index 58bd30b..050091e 100644 --- a/libsignal-service-dotnet/crypto/SignalServiceCipher.cs +++ b/libsignal-service-dotnet/crypto/SignalServiceCipher.cs @@ -15,6 +15,7 @@ using libsignalservicedotnet.crypto; using System; using System.Collections.Generic; +using System.Threading.Tasks; using System.Linq; namespace libsignalservice.crypto @@ -72,77 +73,108 @@ public OutgoingPushMessage Encrypt(SignalProtocolAddress destination, Unidentifi /// Decrypt a received /// /// The received SignalServiceEnvelope + /// Optional callback to call during the decrypt process before it is acked /// a decrypted SignalServiceContent - public SignalServiceContent? Decrypt(SignalServiceEnvelope envelope) + public async Task Decrypt(SignalServiceEnvelope envelope, Func callback = null) { + Func callback_func = null; + if (callback != null) + { + callback_func = async (data) => await callback(await DecryptComplete(envelope, data)); + } try { + Plaintext plaintext = null; if (envelope.HasLegacyMessage()) { - Plaintext plaintext = Decrypt(envelope, envelope.GetLegacyMessage()); - DataMessage message = DataMessage.Parser.ParseFrom(plaintext.Data); + plaintext = await Decrypt(envelope, envelope.GetLegacyMessage(), callback_func); + } + else if (envelope.HasContent()) + { + plaintext = await Decrypt(envelope, envelope.GetContent(), callback_func); + } + if (callback_func != null) + { + return null; + } + return await DecryptComplete(envelope, plaintext); + } + catch (InvalidProtocolBufferException e) + { + throw new InvalidMessageException(e); + } + } + private async Task DecryptComplete(SignalServiceEnvelope envelope, Plaintext plaintext) + { + if (envelope.HasLegacyMessage()) + { + DataMessage message = DataMessage.Parser.ParseFrom(plaintext.Data); + return new SignalServiceContent(plaintext.Metadata.Sender, + plaintext.Metadata.SenderDevice, + plaintext.Metadata.Timestamp, + plaintext.Metadata.NeedsReceipt) + { + Message = CreateSignalServiceMessage(plaintext.Metadata, message) + }; + } + else if (envelope.HasContent()) + { + Content message = Content.Parser.ParseFrom(plaintext.Data); + if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage) + { return new SignalServiceContent(plaintext.Metadata.Sender, - plaintext.Metadata.SenderDevice, - plaintext.Metadata.Timestamp, - plaintext.Metadata.NeedsReceipt) + plaintext.Metadata.SenderDevice, + plaintext.Metadata.Timestamp, + plaintext.Metadata.NeedsReceipt) { - Message = CreateSignalServiceMessage(plaintext.Metadata, message) + Message = CreateSignalServiceMessage(plaintext.Metadata, message.DataMessage) }; } - else if (envelope.HasContent()) + else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage) { - Plaintext plaintext = Decrypt(envelope, envelope.Envelope.Content.ToByteArray()); - Content message = Content.Parser.ParseFrom(plaintext.Data); - if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage) - { - return new SignalServiceContent(plaintext.Metadata.Sender, - plaintext.Metadata.SenderDevice, - plaintext.Metadata.Timestamp, - plaintext.Metadata.NeedsReceipt) - { - Message = CreateSignalServiceMessage(plaintext.Metadata, message.DataMessage) - }; - } - else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage) + return new SignalServiceContent(plaintext.Metadata.Sender, + plaintext.Metadata.SenderDevice, + plaintext.Metadata.Timestamp, + plaintext.Metadata.NeedsReceipt) { - return new SignalServiceContent(plaintext.Metadata.Sender, - plaintext.Metadata.SenderDevice, - plaintext.Metadata.Timestamp, - plaintext.Metadata.NeedsReceipt) - { - SynchronizeMessage = CreateSynchronizeMessage(plaintext.Metadata, message.SyncMessage) - }; - } - else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage) + SynchronizeMessage = CreateSynchronizeMessage(plaintext.Metadata, message.SyncMessage) + }; + } + else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage) + { + return new SignalServiceContent(plaintext.Metadata.Sender, + plaintext.Metadata.SenderDevice, + plaintext.Metadata.Timestamp, + plaintext.Metadata.NeedsReceipt) { - return new SignalServiceContent(plaintext.Metadata.Sender, - plaintext.Metadata.SenderDevice, - plaintext.Metadata.Timestamp, - plaintext.Metadata.NeedsReceipt) - { - CallMessage = CreateCallMessage(message.CallMessage) - }; - } - else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage) + CallMessage = CreateCallMessage(message.CallMessage) + }; + } + else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage) + { + return new SignalServiceContent(plaintext.Metadata.Sender, + plaintext.Metadata.SenderDevice, + plaintext.Metadata.Timestamp, + plaintext.Metadata.NeedsReceipt) { - return new SignalServiceContent(plaintext.Metadata.Sender, - plaintext.Metadata.SenderDevice, - plaintext.Metadata.Timestamp, - plaintext.Metadata.NeedsReceipt) - { - ReadMessage = CreateReceiptMessage(plaintext.Metadata, message.ReceiptMessage) - }; - } + ReadMessage = CreateReceiptMessage(plaintext.Metadata, message.ReceiptMessage) + }; } - return null; } - catch (InvalidProtocolBufferException e) + return null; + } + private class DecryptionCallbackHandler : DecryptionCallback + { + public Task handlePlaintext(byte[] data, uint sessionVersion) { - throw new InvalidMetadataMessageException(e); + data = GetStrippedMessage(sessionVersion, data); + return callback(new Plaintext(metadata, data)); } + public SessionCipher sessionCipher; + public Metadata metadata; + public Func callback; } - - private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext) + private async Task Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext, Func<Plaintext, Task> callback = null) { try { @@ -153,15 +185,27 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext) byte[] paddedMessage; Metadata metadata; uint sessionVersion; - + DecryptionCallbackHandler callback_handler = null; + if (callback != null) + callback_handler = new DecryptionCallbackHandler { callback = callback, sessionCipher = sessionCipher }; if (envelope.IsPreKeySignalMessage()) { - paddedMessage = sessionCipher.decrypt(new PreKeySignalMessage(ciphertext)); metadata = new Metadata(envelope.GetSource(), envelope.GetSourceDevice(), envelope.GetTimestamp(), false); + if (callback_handler != null) + { + await sessionCipher.decrypt(new PreKeySignalMessage(ciphertext), callback_handler); + return null; + } + paddedMessage = sessionCipher.decrypt(new PreKeySignalMessage(ciphertext)); sessionVersion = sessionCipher.getSessionVersion(); } else if (envelope.IsSignalMessage()) { + if (callback_handler != null) + { + await sessionCipher.decrypt(new SignalMessage(ciphertext), callback_handler); + return null; + } paddedMessage = sessionCipher.decrypt(new SignalMessage(ciphertext)); metadata = new Metadata(envelope.GetSource(), envelope.GetSourceDevice(), envelope.GetTimestamp(), false); sessionVersion = sessionCipher.getSessionVersion(); @@ -170,16 +214,14 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext) { var results = sealedSessionCipher.Decrypt(CertificateValidator, ciphertext, (long)envelope.Envelope.ServerTimestamp); paddedMessage = results.Item2; - metadata = new Metadata(results.Item1.Name, (int) results.Item1.DeviceId, (long) envelope.Envelope.Timestamp, true); - sessionVersion = (uint) sealedSessionCipher.GetSessionVersion(new SignalProtocolAddress(metadata.Sender, (uint) metadata.SenderDevice)); + metadata = new Metadata(results.Item1.Name, (int)results.Item1.DeviceId, (long)envelope.Envelope.Timestamp, true); + sessionVersion = (uint)sealedSessionCipher.GetSessionVersion(new SignalProtocolAddress(metadata.Sender, (uint)metadata.SenderDevice)); } else { throw new InvalidMessageException("Unknown type: " + envelope.GetEnvelopeType() + " from " + envelope.GetSource()); } - - PushTransportDetails transportDetails = new PushTransportDetails(sessionVersion); - byte[] data = transportDetails.GetStrippedPaddingMessageBody(paddedMessage); + var data = GetStrippedMessage(sessionVersion, paddedMessage); return new Plaintext(metadata, data); } catch (DuplicateMessageException e) @@ -214,7 +256,15 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext) { throw new ProtocolNoSessionException(e, envelope.GetSource(), envelope.GetSourceDevice()); } + } + private static byte[] GetStrippedMessage(uint sessionVersion, byte[] paddedMessage) + { + PushTransportDetails transportDetails = new PushTransportDetails(sessionVersion); + byte[] data = transportDetails.GetStrippedPaddingMessageBody(paddedMessage); + return data; + } + private SignalServiceDataMessage CreateSignalServiceMessage(Metadata metadata, DataMessage content) {