Skip to content

Added callback support and as async (requires change to libsignal-protocol-dotnet for it) #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 109 additions & 59 deletions libsignal-service-dotnet/crypto/SignalServiceCipher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using libsignalservicedotnet.crypto;
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using System.Linq;

namespace libsignalservice.crypto
Expand Down Expand Up @@ -72,77 +73,108 @@ public OutgoingPushMessage Encrypt(SignalProtocolAddress destination, Unidentifi
/// Decrypt a received <see cref="SignalServiceEnvelope"/>
/// </summary>
/// <param name="envelope">The received SignalServiceEnvelope</param>
/// <param name="callback">Optional callback to call during the decrypt process before it is acked</param>
/// <returns>a decrypted SignalServiceContent</returns>
public SignalServiceContent? Decrypt(SignalServiceEnvelope envelope)
public async Task<SignalServiceContent?> Decrypt(SignalServiceEnvelope envelope, Func<SignalServiceContent?, Task> callback = null)
{
Func<Plaintext, Task> callback_func = null;
if (callback != null)
{
callback_func = async (data) => await callback(await DecryptComplete(envelope, data));
}
try
{
Plaintext plaintext = null;
if (envelope.HasLegacyMessage())
{
Plaintext plaintext = Decrypt(envelope, envelope.GetLegacyMessage());
DataMessage message = DataMessage.Parser.ParseFrom(plaintext.Data);
plaintext = await Decrypt(envelope, envelope.GetLegacyMessage(), callback_func);
}
else if (envelope.HasContent())
{
plaintext = await Decrypt(envelope, envelope.GetContent(), callback_func);
}
if (callback_func != null)
{
return null;
}
return await DecryptComplete(envelope, plaintext);
}
catch (InvalidProtocolBufferException e)
{
throw new InvalidMessageException(e);
}
}
private async Task<SignalServiceContent> DecryptComplete(SignalServiceEnvelope envelope, Plaintext plaintext)
{
if (envelope.HasLegacyMessage())
{
DataMessage message = DataMessage.Parser.ParseFrom(plaintext.Data);
return new SignalServiceContent(plaintext.Metadata.Sender,
plaintext.Metadata.SenderDevice,
plaintext.Metadata.Timestamp,
plaintext.Metadata.NeedsReceipt)
{
Message = CreateSignalServiceMessage(plaintext.Metadata, message)
};
}
else if (envelope.HasContent())
{
Content message = Content.Parser.ParseFrom(plaintext.Data);
if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage)
{
return new SignalServiceContent(plaintext.Metadata.Sender,
plaintext.Metadata.SenderDevice,
plaintext.Metadata.Timestamp,
plaintext.Metadata.NeedsReceipt)
plaintext.Metadata.SenderDevice,
plaintext.Metadata.Timestamp,
plaintext.Metadata.NeedsReceipt)
{
Message = CreateSignalServiceMessage(plaintext.Metadata, message)
Message = CreateSignalServiceMessage(plaintext.Metadata, message.DataMessage)
};
}
else if (envelope.HasContent())
else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage)
{
Plaintext plaintext = Decrypt(envelope, envelope.Envelope.Content.ToByteArray());
Content message = Content.Parser.ParseFrom(plaintext.Data);
if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage)
{
return new SignalServiceContent(plaintext.Metadata.Sender,
plaintext.Metadata.SenderDevice,
plaintext.Metadata.Timestamp,
plaintext.Metadata.NeedsReceipt)
{
Message = CreateSignalServiceMessage(plaintext.Metadata, message.DataMessage)
};
}
else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage)
return new SignalServiceContent(plaintext.Metadata.Sender,
plaintext.Metadata.SenderDevice,
plaintext.Metadata.Timestamp,
plaintext.Metadata.NeedsReceipt)
{
return new SignalServiceContent(plaintext.Metadata.Sender,
plaintext.Metadata.SenderDevice,
plaintext.Metadata.Timestamp,
plaintext.Metadata.NeedsReceipt)
{
SynchronizeMessage = CreateSynchronizeMessage(plaintext.Metadata, message.SyncMessage)
};
}
else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage)
SynchronizeMessage = CreateSynchronizeMessage(plaintext.Metadata, message.SyncMessage)
};
}
else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage)
{
return new SignalServiceContent(plaintext.Metadata.Sender,
plaintext.Metadata.SenderDevice,
plaintext.Metadata.Timestamp,
plaintext.Metadata.NeedsReceipt)
{
return new SignalServiceContent(plaintext.Metadata.Sender,
plaintext.Metadata.SenderDevice,
plaintext.Metadata.Timestamp,
plaintext.Metadata.NeedsReceipt)
{
CallMessage = CreateCallMessage(message.CallMessage)
};
}
else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage)
CallMessage = CreateCallMessage(message.CallMessage)
};
}
else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage)
{
return new SignalServiceContent(plaintext.Metadata.Sender,
plaintext.Metadata.SenderDevice,
plaintext.Metadata.Timestamp,
plaintext.Metadata.NeedsReceipt)
{
return new SignalServiceContent(plaintext.Metadata.Sender,
plaintext.Metadata.SenderDevice,
plaintext.Metadata.Timestamp,
plaintext.Metadata.NeedsReceipt)
{
ReadMessage = CreateReceiptMessage(plaintext.Metadata, message.ReceiptMessage)
};
}
ReadMessage = CreateReceiptMessage(plaintext.Metadata, message.ReceiptMessage)
};
}
return null;
}
catch (InvalidProtocolBufferException e)
return null;
}
private class DecryptionCallbackHandler : DecryptionCallback
{
public Task handlePlaintext(byte[] data, uint sessionVersion)
{
throw new InvalidMetadataMessageException(e);
data = GetStrippedMessage(sessionVersion, data);
return callback(new Plaintext(metadata, data));
}
public SessionCipher sessionCipher;
public Metadata metadata;
public Func<Plaintext, Task> callback;
}

private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
private async Task<Plaintext> Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext, Func<Plaintext, Task> callback = null)
{
try
{
Expand All @@ -153,15 +185,27 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
byte[] paddedMessage;
Metadata metadata;
uint sessionVersion;

DecryptionCallbackHandler callback_handler = null;
if (callback != null)
callback_handler = new DecryptionCallbackHandler { callback = callback, sessionCipher = sessionCipher };
if (envelope.IsPreKeySignalMessage())
{
paddedMessage = sessionCipher.decrypt(new PreKeySignalMessage(ciphertext));
metadata = new Metadata(envelope.GetSource(), envelope.GetSourceDevice(), envelope.GetTimestamp(), false);
if (callback_handler != null)
{
await sessionCipher.decrypt(new PreKeySignalMessage(ciphertext), callback_handler);
return null;
}
paddedMessage = sessionCipher.decrypt(new PreKeySignalMessage(ciphertext));
sessionVersion = sessionCipher.getSessionVersion();
}
else if (envelope.IsSignalMessage())
{
if (callback_handler != null)
{
await sessionCipher.decrypt(new SignalMessage(ciphertext), callback_handler);
return null;
}
paddedMessage = sessionCipher.decrypt(new SignalMessage(ciphertext));
metadata = new Metadata(envelope.GetSource(), envelope.GetSourceDevice(), envelope.GetTimestamp(), false);
sessionVersion = sessionCipher.getSessionVersion();
Expand All @@ -170,16 +214,14 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
{
var results = sealedSessionCipher.Decrypt(CertificateValidator, ciphertext, (long)envelope.Envelope.ServerTimestamp);
paddedMessage = results.Item2;
metadata = new Metadata(results.Item1.Name, (int) results.Item1.DeviceId, (long) envelope.Envelope.Timestamp, true);
sessionVersion = (uint) sealedSessionCipher.GetSessionVersion(new SignalProtocolAddress(metadata.Sender, (uint) metadata.SenderDevice));
metadata = new Metadata(results.Item1.Name, (int)results.Item1.DeviceId, (long)envelope.Envelope.Timestamp, true);
sessionVersion = (uint)sealedSessionCipher.GetSessionVersion(new SignalProtocolAddress(metadata.Sender, (uint)metadata.SenderDevice));
}
else
{
throw new InvalidMessageException("Unknown type: " + envelope.GetEnvelopeType() + " from " + envelope.GetSource());
}

PushTransportDetails transportDetails = new PushTransportDetails(sessionVersion);
byte[] data = transportDetails.GetStrippedPaddingMessageBody(paddedMessage);
var data = GetStrippedMessage(sessionVersion, paddedMessage);
return new Plaintext(metadata, data);
}
catch (DuplicateMessageException e)
Expand Down Expand Up @@ -214,7 +256,15 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
{
throw new ProtocolNoSessionException(e, envelope.GetSource(), envelope.GetSourceDevice());
}

}
private static byte[] GetStrippedMessage(uint sessionVersion, byte[] paddedMessage)
{
PushTransportDetails transportDetails = new PushTransportDetails(sessionVersion);
byte[] data = transportDetails.GetStrippedPaddingMessageBody(paddedMessage);
return data;
}


private SignalServiceDataMessage CreateSignalServiceMessage(Metadata metadata, DataMessage content)
{
Expand Down