Skip to content

Commit b771224

Browse files
committed
Remove alloc on benchmarks + extrachecks on DecodeUrlSSE
1 parent 7ad3374 commit b771224

File tree

3 files changed

+91
-3
lines changed

3 files changed

+91
-3
lines changed

benchmark/Benchmark.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,12 @@ public unsafe void RunSSEDecodingBenchmarkUTF16(string[] data, int[] lengths)
299299
for (int i = 0; i < FileContent.Length; i++)
300300
{
301301
string s = FileContent[i];
302-
char[] base64 = s.ToCharArray();
302+
// char[] base64 = s.ToCharArray();
303+
ReadOnlySpan<char> base64 = s.AsSpan();
303304
byte[] dataoutput = output[i];
304305
int bytesConsumed = 0;
305306
int bytesWritten = 0;
306-
SimdBase64.Base64.DecodeFromBase64SSE(base64.AsSpan(), dataoutput, out bytesConsumed, out bytesWritten, false);
307+
SimdBase64.Base64.DecodeFromBase64SSE(base64, dataoutput, out bytesConsumed, out bytesWritten, false);
307308
if (bytesWritten != lengths[i])
308309
{
309310
Console.WriteLine($"Error: {bytesWritten} != {lengths[i]}");

src/Base64SSEUTF16.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,23 @@ private unsafe static OperationStatus InnerDecodeFromBase64SSEUrl(ReadOnlySpan<c
554554
int lastBlockSrcCount = 0;
555555
while ((bufferPtr - startOfBuffer) % 64 != 0 && src < srcEnd)
556556
{
557+
558+
if (!IsValidBase64Index(*src))
559+
{
560+
bytesConsumed = Math.Max(0, (int)(src - srcInit) - lastBlockSrcCount - (int)bufferBytesConsumed);
561+
bytesWritten = Math.Max(0, (int)(dst - dstInit) - (int)bufferBytesWritten);
562+
563+
int remainderBytesConsumed = 0;
564+
int remainderBytesWritten = 0;
565+
566+
OperationStatus result =
567+
Base64WithWhiteSpaceToBinaryScalar(source.Slice(Math.Max(0, bytesConsumed)), dest.Slice(Math.Max(0, bytesWritten)), out remainderBytesConsumed, out remainderBytesWritten, isUrl);
568+
569+
bytesConsumed += remainderBytesConsumed;
570+
bytesWritten += remainderBytesWritten;
571+
return result;
572+
}
573+
557574
byte val = toBase64[(int)*src];
558575
*bufferPtr = val;
559576
if (val > 64)
@@ -625,6 +642,14 @@ private unsafe static OperationStatus InnerDecodeFromBase64SSEUrl(ReadOnlySpan<c
625642

626643
while (leftover < 4 && src < srcEnd)
627644
{
645+
646+
if (!IsValidBase64Index(*src))
647+
{
648+
bytesConsumed = (int)(src - srcInit);
649+
bytesWritten = (int)(dst - dstInit);
650+
return OperationStatus.InvalidData;
651+
}
652+
628653
byte val = toBase64[(byte)*src];
629654
if (val > 64)
630655
{

test/Base64DecodingTestsUTF16.cs

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ protected void RoundtripBase64UrlUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Bas
305305
#pragma warning disable CA5394 // Do not use insecure randomness
306306
random.NextBytes(source);
307307

308-
string base64String = Convert.ToBase64String(source).Replace('+', '-').Replace('/', '_'); ;
308+
string base64String = Convert.ToBase64String(source).Replace('+', '-').Replace('/', '_');
309309

310310
byte[] decodedBytes = new byte[len];
311311

@@ -1372,6 +1372,68 @@ public void TruncatedCharErrorUTF16SSE()
13721372
TruncatedCharErrorUTF16(Base64.DecodeFromBase64SSE,Base64.SafeBase64ToBinaryWithWhiteSpace);
13731373
}
13741374

1375+
protected void TruncatedCharErrorUrlUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16,DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16)
1376+
{
1377+
1378+
string badNonASCIIString = "♡♡♡♡";
1379+
1380+
for (int len = 0; len < 2048; len++)
1381+
{
1382+
byte[] source = new byte[len];
1383+
1384+
for (int trial = 0; trial < 10; trial++)
1385+
{
1386+
int bytesConsumed = 0;
1387+
int bytesWritten = 0;
1388+
#pragma warning disable CA5394 // Do not use insecure randomness
1389+
random.NextBytes(source); // Generate random bytes for source
1390+
1391+
string base64 = Convert.ToBase64String(source).Replace('+', '-').Replace('/', '_');
1392+
1393+
int location = random.Next(0, base64.Length + 1)/4;
1394+
1395+
char[] base64WithGarbage = base64.Insert(location, badNonASCIIString).ToCharArray();
1396+
1397+
// Prepare a buffer for decoding the base64 back to binary
1398+
byte[] back = new byte[Base64.MaximalBinaryLengthFromBase64Scalar<char>(base64WithGarbage)];
1399+
1400+
// Attempt to decode base64 back to binary and assert that it fails with INVALID_BASE64_CHARACTER
1401+
var result = Base64WithWhiteSpaceToBinaryFromUTF16(
1402+
base64WithGarbage.AsSpan(), back.AsSpan(),
1403+
out bytesConsumed, out bytesWritten, isUrl: true);
1404+
Assert.True(OperationStatus.InvalidData == result, $"OperationStatus {result} is not Invalid Data, error at location {location}. ");
1405+
Assert.Equal(location, bytesConsumed);
1406+
Assert.Equal(location / 4 * 3, bytesWritten);
1407+
1408+
// Also test safe decoding with a specified back_length
1409+
var safeResult = DecodeFromBase64DelegateSafeFromUTF16(
1410+
base64WithGarbage.AsSpan(), back.AsSpan(),
1411+
out bytesConsumed, out bytesWritten, isUrl: true);
1412+
Assert.Equal(OperationStatus.InvalidData, safeResult);
1413+
Assert.Equal(location, bytesConsumed);
1414+
Assert.Equal(location / 4 * 3, bytesWritten);
1415+
1416+
}
1417+
}
1418+
1419+
1420+
}
1421+
1422+
[Fact]
1423+
[Trait("Category", "scalar")]
1424+
public void TruncatedCharErrorUrlScalarUTF16()
1425+
{
1426+
TruncatedCharErrorUrlUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar,Base64.SafeBase64ToBinaryWithWhiteSpace);
1427+
}
1428+
1429+
1430+
[Fact]
1431+
[Trait("Category", "sse")]
1432+
public void TruncatedCharErrorUrlUTF16SSE()
1433+
{
1434+
TruncatedCharErrorUrlUTF16(Base64.DecodeFromBase64SSE,Base64.SafeBase64ToBinaryWithWhiteSpace);
1435+
}
1436+
13751437

13761438
}
13771439

0 commit comments

Comments
 (0)