Skip to content

Commit b882663

Browse files
committed
Refactor to buffer output using SdkByteChannel
1 parent 4b4c628 commit b882663

File tree

3 files changed

+48
-29
lines changed

3 files changed

+48
-29
lines changed

runtime/runtime-core/native/src/aws/smithy/kotlin/runtime/compression/GzipCompressor.kt

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,14 @@
55
package aws.smithy.kotlin.runtime.compression
66

77
import aws.sdk.kotlin.crt.Closeable
8+
import aws.smithy.kotlin.runtime.io.SdkByteChannel
9+
import aws.smithy.kotlin.runtime.io.readFully
10+
import aws.smithy.kotlin.runtime.io.write
811
import kotlinx.cinterop.*
912
import platform.zlib.*
13+
import aws.smithy.kotlin.runtime.io.SdkBuffer
14+
import aws.smithy.kotlin.runtime.io.readToByteArray
15+
import aws.smithy.kotlin.runtime.io.use
1016

1117
private const val DEFAULT_WINDOW_BITS = 15 // Default window bits
1218
private const val WINDOW_BITS_GZIP_OFFSET = 16 // Gzip offset for window bits
@@ -20,13 +26,12 @@ internal class GzipCompressor : Closeable {
2026
internal const val BUFFER_SIZE = 16384
2127
}
2228

23-
private val buffer = ByteArray(BUFFER_SIZE)
2429
private val stream = nativeHeap.alloc<z_stream>()
25-
private val outputBuffer = ArrayList<Byte>()
30+
private val outputBuffer = SdkByteChannel()
2631
internal var isClosed = false
2732

2833
internal val availableForRead: Int
29-
get() = outputBuffer.size
34+
get() = outputBuffer.availableForRead
3035

3136
init {
3237
// Initialize deflate with gzip encoding
@@ -47,22 +52,26 @@ internal class GzipCompressor : Closeable {
4752
/**
4853
* Update the compressor with [input] bytes
4954
*/
50-
fun update(input: ByteArray) = memScoped {
55+
suspend fun update(input: ByteArray) = memScoped {
56+
check (!isClosed) { "Compressor is closed" }
57+
5158
val inputPin = input.pin()
5259

5360
stream.next_in = inputPin.addressOf(0).reinterpret()
5461
stream.avail_in = input.size.toUInt()
5562

63+
val compressionBuffer = ByteArray(BUFFER_SIZE)
64+
5665
while (stream.avail_in > 0u) {
57-
val outputPin = buffer.pin()
66+
val outputPin = compressionBuffer.pin()
5867
stream.next_out = outputPin.addressOf(0).reinterpret()
5968
stream.avail_out = BUFFER_SIZE.toUInt()
6069

6170
val deflateResult = deflate(stream.ptr, Z_NO_FLUSH)
6271
check(deflateResult == Z_OK) { "Deflate failed with error code $deflateResult" }
6372

6473
val bytesWritten = BUFFER_SIZE - stream.avail_out.toInt()
65-
outputBuffer.addAll(buffer.take(bytesWritten))
74+
outputBuffer.write(compressionBuffer, 0, bytesWritten)
6675

6776
outputPin.unpin()
6877
}
@@ -73,43 +82,48 @@ internal class GzipCompressor : Closeable {
7382
/**
7483
* Consume [count] gzip-compressed bytes.
7584
*/
76-
fun consume(count: Int): ByteArray {
85+
suspend fun consume(count: Int): ByteArray {
86+
check (!isClosed) { "Compressor is closed" }
7787
require(count in 0..availableForRead) {
7888
"Count must be between 0 and $availableForRead, got $count"
7989
}
8090

81-
val result = outputBuffer.take(count).toByteArray()
82-
repeat(count) { outputBuffer.removeAt(0) }
83-
return result
91+
return SdkBuffer().use {
92+
outputBuffer.readFully(it, count.toLong())
93+
it.readToByteArray()
94+
}
8495
}
8596

8697
/**
8798
* Flush the compressor and return the terminal sequence of bytes that represent the end of the gzip compression.
8899
*/
89-
fun flush(): ByteArray {
90-
if (isClosed) {
91-
return byteArrayOf()
92-
}
100+
suspend fun flush(): ByteArray {
101+
check (!isClosed) { "Compressor is closed" }
93102

94103
memScoped {
95-
var finished = false
104+
val compressionBuffer = ByteArray(BUFFER_SIZE)
105+
var deflateResult: Int? = null
106+
var outputLength = 0L
96107

97-
while (!finished) {
98-
val outputPin = buffer.pin()
108+
do {
109+
val outputPin = compressionBuffer.pin()
99110
stream.next_out = outputPin.addressOf(0).reinterpret()
100111
stream.avail_out = BUFFER_SIZE.toUInt()
101112

102-
val deflateResult = deflate(stream.ptr, Z_FINISH)
113+
deflateResult = deflate(stream.ptr, Z_FINISH)
103114
check(deflateResult == Z_OK || deflateResult == Z_STREAM_END) { "Deflate failed during finish with error code $deflateResult" }
104115

105116
val bytesWritten = BUFFER_SIZE - stream.avail_out.toInt()
106-
outputBuffer.addAll(buffer.take(bytesWritten))
117+
outputBuffer.write(compressionBuffer, 0, bytesWritten)
107118

108-
finished = deflateResult == Z_STREAM_END
119+
outputLength += bytesWritten.toLong()
109120
outputPin.unpin()
110-
}
121+
} while (deflateResult != Z_STREAM_END)
111122

112-
return outputBuffer.toByteArray()
123+
return SdkBuffer().use {
124+
outputBuffer.readFully(it, outputLength)
125+
it.readByteArray()
126+
}
113127
}
114128
}
115129

runtime/runtime-core/native/src/aws/smithy/kotlin/runtime/compression/GzipNative.kt

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import aws.smithy.kotlin.runtime.io.GzipByteReadChannel
1010
import aws.smithy.kotlin.runtime.io.GzipSdkSource
1111
import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
1212
import aws.smithy.kotlin.runtime.io.SdkSource
13+
import kotlinx.coroutines.runBlocking
1314

1415
/**
1516
* The gzip compression algorithm.
@@ -39,10 +40,12 @@ public actual class Gzip : CompressionAlgorithm {
3940
if (sourceBytes.isEmpty()) {
4041
stream
4142
} else {
42-
val compressed = GzipCompressor().use {
43-
it.apply {
44-
update(sourceBytes)
45-
}.flush()
43+
val compressed = runBlocking {
44+
GzipCompressor().use {
45+
it.apply {
46+
update(sourceBytes)
47+
}.flush()
48+
}
4649
}
4750

4851
ByteStream.fromBytes(compressed)

runtime/runtime-core/native/src/aws/smithy/kotlin/runtime/io/GzipSdkSourceNative.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package aws.smithy.kotlin.runtime.io
66

77
import aws.smithy.kotlin.runtime.InternalApi
88
import aws.smithy.kotlin.runtime.compression.GzipCompressor
9+
import kotlinx.coroutines.runBlocking
910

1011
/**
1112
* Wraps an [SdkSource], compressing bytes read into GZIP format.
@@ -27,13 +28,14 @@ public actual class GzipSdkSource actual constructor(public val source: SdkSourc
2728

2829
if (rc > 0) {
2930
val input = temp.readByteArray(rc)
30-
compressor.update(input)
31+
runBlocking { compressor.update(input) }
32+
3133
}
3234
}
3335

3436
// If still no data is available, we've hit EOF. Close the compressor and write the remaining bytes
3537
if (compressor.availableForRead == 0) {
36-
val terminationBytes = compressor.flush()
38+
val terminationBytes = runBlocking { compressor.flush() }
3739
sink.write(terminationBytes)
3840
return terminationBytes.size.toLong().also {
3941
compressor.close()
@@ -42,7 +44,7 @@ public actual class GzipSdkSource actual constructor(public val source: SdkSourc
4244

4345
// Read compressed bytes from the compressor
4446
val bytesToRead = minOf(limit, compressor.availableForRead.toLong())
45-
val compressed = compressor.consume(bytesToRead.toInt())
47+
val compressed = runBlocking { compressor.consume(bytesToRead.toInt()) }
4648
sink.write(compressed)
4749
return compressed.size.toLong()
4850
}

0 commit comments

Comments
 (0)