diff --git a/Xledger.Collections.Test/TestMemoryOwner.cs b/Xledger.Collections.Test/TestMemoryOwner.cs index 1906e7d..215db00 100644 --- a/Xledger.Collections.Test/TestMemoryOwner.cs +++ b/Xledger.Collections.Test/TestMemoryOwner.cs @@ -23,7 +23,6 @@ public void TestSlice() { Assert.Equal(array.AsMemory().Slice(3).Slice(1, 4), sliced2.Memory); } -#if NET [Fact] public void TestStream_ToMemoryOwner() { // This length should be larger than the default GetCopyBufferSize. @@ -31,7 +30,15 @@ public void TestStream_ToMemoryOwner() { new Random().NextBytes(array); var ms = new MemoryStream(array); using var memoryOwner = ms.ToOwnedMemory(); +#if NET Assert.Equal(array.AsMemory(), memoryOwner.Memory); +#else + var mem = memoryOwner.Memory; + Assert.Equal(array.Length, mem.Length); + for (int i = 0; i < array.Length; ++i) { + Assert.Equal(array[i], mem.Span[i]); + } +#endif } [Fact] @@ -41,7 +48,67 @@ public async Task TestStream_ToMemoryOwnerAsync() { new Random().NextBytes(array); var ms = new MemoryStream(array); using var memoryOwner = await ms.ToOwnedMemoryAsync(); +#if NET Assert.Equal(array.AsMemory(), memoryOwner.Memory); +#else + var mem = memoryOwner.Memory; + Assert.Equal(array.Length, mem.Length); + for (int i = 0; i < array.Length; ++i) { + Assert.Equal(array[i], mem.Span[i]); + } +#endif } + + // Tests that reading from a stream with an unknown size does so correctly. + [Fact] + public void TestUnsizedStream_ToMemoryOwner() { + // This length should be larger than the default GetCopyBufferSize. + byte[] array = new byte[2 * 1024 * 1024 + 17]; + new Random().NextBytes(array); + var ms = new UnsizedMemoryStream(array); + using var memoryOwner = ms.ToOwnedMemory(); +#if NET + Assert.Equal(array.AsMemory(), memoryOwner.Memory); +#else + var mem = memoryOwner.Memory; + Assert.Equal(array.Length, mem.Length); + for (int i = 0; i < array.Length; ++i) { + Assert.Equal(array[i], mem.Span[i]); + } #endif + } + + // Tests that reading from a stream with an unknown size does so correctly. + [Fact] + public async Task TestUnsizedStream_ToMemoryOwnerAsync() { + // This length should be larger than the default GetCopyBufferSize. + byte[] array = new byte[2 * 1024 * 1024 + 17]; + new Random().NextBytes(array); + var ms = new UnsizedMemoryStream(array); + using var memoryOwner = await ms.ToOwnedMemoryAsync(); +#if NET + Assert.Equal(array.AsMemory(), memoryOwner.Memory); +#else + var mem = memoryOwner.Memory; + Assert.Equal(array.Length, mem.Length); + for (int i = 0; i < array.Length; ++i) { + Assert.Equal(array[i], mem.Span[i]); + } +#endif + } + + class UnsizedMemoryStream(byte[] buffer) : MemoryStream(buffer) { + public override bool CanSeek => false; + + public override int Capacity { + get => throw new NotImplementedException(); + set => throw new NotImplementedException(); + } + + public override long Length => throw new NotImplementedException(); + + public override void SetLength(long value) { + throw new NotImplementedException(); + } + } } diff --git a/Xledger.Collections/Memory/Extensions.cs b/Xledger.Collections/Memory/Extensions.cs index df0fc7d..da6eca8 100644 --- a/Xledger.Collections/Memory/Extensions.cs +++ b/Xledger.Collections/Memory/Extensions.cs @@ -24,23 +24,32 @@ public static IMemoryOwner Slice(this IMemoryOwner memoryOwner, int sta return new SizedMemoryOwner(memoryOwner, start, length); } + static readonly int ArrayMaxLength = #if NET + Array.MaxLength; +#else + int.MaxValue; +#endif + + static readonly ThreadLocal PROBE = new(() => new byte[1]); + public static IMemoryOwner ToOwnedMemory(this Stream source, bool leaveOpen = false) { if (source == null) { throw new ArgumentNullException(nameof(source)); } - int initialBufLen = GetCopyBufferSize(source); + (bool canHoldEntireStream, int initialBufLen) = GetBufferSize(source); - var currentOwner = MemoryPool.Shared.Rent(initialBufLen); - var currentBuffer = currentOwner.Memory; + var currentBuffer = ArrayPool.Shared.Rent(initialBufLen); + var currentOwner = currentBuffer.ToOwnedMemory(ArrayPool.Shared); int totalBytesRead = 0; try { while (true) { - var dest = currentBuffer.Slice(totalBytesRead); - - int bytesRead = source.Read(dest.Span); + int bytesRead = source.Read( + currentBuffer, + totalBytesRead, + currentBuffer.Length - totalBytesRead); if (bytesRead == 0) { break; @@ -48,27 +57,36 @@ public static IMemoryOwner ToOwnedMemory(this Stream source, bool leaveOpe totalBytesRead += bytesRead; + if (canHoldEntireStream && totalBytesRead == initialBufLen) { + // We've read the entire stream. + break; + } + if (totalBytesRead != currentBuffer.Length) { continue; } - if (currentBuffer.Length == Array.MaxLength) { + if (currentBuffer.Length == ArrayMaxLength) { +#if NET Span probe = stackalloc byte[1]; if (source.Read(probe) > 0) { - throw new IOException($"Stream exceeds the maximum bufferable array size of {Array.MaxLength} bytes."); +#else + if (source.Read(PROBE.Value, 0, 1) > 0) { +#endif + throw new IOException($"Stream exceeds the maximum bufferable array size of {ArrayMaxLength} bytes."); } break; // we are at the end of the stream } var newCapacity = (long)currentBuffer.Length * 2; - if (newCapacity > Array.MaxLength) { - newCapacity = Array.MaxLength; + if (newCapacity > ArrayMaxLength) { + newCapacity = ArrayMaxLength; } - var newOwner = MemoryPool.Shared.Rent((int)newCapacity); - var newBuffer = newOwner.Memory; + var newBuffer = ArrayPool.Shared.Rent((int)newCapacity); + var newOwner = newBuffer.ToOwnedMemory(ArrayPool.Shared); - currentBuffer.CopyTo(newBuffer); + currentBuffer.CopyTo(newBuffer.AsSpan()); currentOwner.Dispose(); currentOwner = newOwner; currentBuffer = newBuffer; @@ -85,26 +103,26 @@ public static IMemoryOwner ToOwnedMemory(this Stream source, bool leaveOpe return currentOwner.Slice(0, totalBytesRead); } - static readonly byte[] ASYNC_PROBE = new byte[1]; - public static async Task> ToOwnedMemoryAsync(this Stream source, bool leaveOpen = false, CancellationToken tok = default) { if (source == null) { throw new ArgumentNullException(nameof(source)); } - int initialBufLen = GetCopyBufferSize(source); + (bool canHoldEntireStream, int initialBufLen) = GetBufferSize(source); - var currentOwner = MemoryPool.Shared.Rent(initialBufLen); - var currentBuffer = currentOwner.Memory; + var currentBuffer = ArrayPool.Shared.Rent(initialBufLen); + var currentOwner = currentBuffer.ToOwnedMemory(ArrayPool.Shared); int totalBytesRead = 0; try { while (true) { tok.ThrowIfCancellationRequested(); - var dest = currentBuffer.Slice(totalBytesRead); - - int bytesRead = await source.ReadAsync(dest, tok).ConfigureAwait(false); + int bytesRead = await source.ReadAsync( + currentBuffer, + totalBytesRead, + currentBuffer.Length - totalBytesRead, + tok).ConfigureAwait(false); if (bytesRead == 0) { break; @@ -112,26 +130,31 @@ public static async Task> ToOwnedMemoryAsync(this Stream sour totalBytesRead += bytesRead; + if (canHoldEntireStream && totalBytesRead == initialBufLen) { + // We've read the entire stream. + break; + } + if (totalBytesRead != currentBuffer.Length) { continue; } - if (currentBuffer.Length == Array.MaxLength) { - if (await source.ReadAsync(ASYNC_PROBE, tok).ConfigureAwait(false) > 0) { - throw new IOException($"Stream exceeds the maximum bufferable array size of {Array.MaxLength} bytes."); + if (currentBuffer.Length == ArrayMaxLength) { + if (await source.ReadAsync(PROBE.Value, 0, 1, tok).ConfigureAwait(false) > 0) { + throw new IOException($"Stream exceeds the maximum bufferable array size of {ArrayMaxLength} bytes."); } break; // we are at the end of the stream } var newCapacity = (long)currentBuffer.Length * 2; - if (newCapacity > Array.MaxLength) { - newCapacity = Array.MaxLength; + if (newCapacity > ArrayMaxLength) { + newCapacity = ArrayMaxLength; } - var newOwner = MemoryPool.Shared.Rent((int)newCapacity); - var newBuffer = newOwner.Memory; + var newBuffer = ArrayPool.Shared.Rent((int)newCapacity); + var newOwner = newBuffer.ToOwnedMemory(ArrayPool.Shared); - currentBuffer.CopyTo(newBuffer); + currentBuffer.CopyTo(newBuffer.AsSpan()); currentOwner.Dispose(); currentOwner = newOwner; currentBuffer = newBuffer; @@ -148,8 +171,9 @@ public static async Task> ToOwnedMemoryAsync(this Stream sour return currentOwner.Slice(0, totalBytesRead); } - // Copied from System.IO.Stream, adapted to be static - static int GetCopyBufferSize(Stream stream) { + // Initially copied from System.IO.Stream, adapted to be static and to match + // the use above which is to copy an entire stream into a single array. + static (bool isSufficient, int length) GetBufferSize(Stream stream) { // This value was originally picked to be the largest multiple of 4096 that is still smaller than the large object heap threshold (85K). // The CopyTo{Async} buffer is short-lived and is likely to be collected at Gen0, and it offers a significant improvement in Copy // performance. Since then, the base implementations of CopyTo{Async} have been updated to use ArrayPool, which will end up rounding @@ -158,6 +182,7 @@ static int GetCopyBufferSize(Stream stream) { // benefits to using the larger buffer size. So, for now, this value remains. const int DefaultCopyBufferSize = 81920; + bool isSufficient = false; int bufferSize = DefaultCopyBufferSize; if (stream.CanSeek) { @@ -172,16 +197,18 @@ static int GetCopyBufferSize(Stream stream) { bufferSize = 1; } else { long remaining = length - position; - if (remaining > 0) { - // In the case of a positive overflow, stick to the default size - bufferSize = (int)Math.Min(bufferSize, remaining); + if (remaining > ArrayMaxLength) { + throw new IOException($"Stream exceeds the maximum bufferable array size of {ArrayMaxLength} bytes."); + } else if (remaining > 0) { + // If there is some remaining amount in the stream, we copy into a buffer of that size. + isSufficient = true; + bufferSize = (int)remaining; } } } - return bufferSize; + return (isSufficient, bufferSize); } -#endif /// /// Adapt an array to IMemoryOwner. If you pass in an ArrayPool owner, the Array will be returned to the pool on dispose. diff --git a/Xledger.Collections/Xledger.Collections.csproj b/Xledger.Collections/Xledger.Collections.csproj index db3838b..2ac83c6 100644 --- a/Xledger.Collections/Xledger.Collections.csproj +++ b/Xledger.Collections/Xledger.Collections.csproj @@ -6,7 +6,7 @@ Xledger.Collections net48;net8.0 - 12.0 + 13.0 disable disable