Skip to content

Commit c5eb667

Browse files
committed
Fix conflicts and rework serialization with compression
1 parent 533b71f commit c5eb667

12 files changed

+154
-51
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ set(SPARROW_IPC_HEADERS
109109
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/config.hpp
110110
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/sparrow_ipc_version.hpp
111111
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/compression.hpp
112-
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_array.hpp
113112
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_fixedsizebinary_array.hpp
114113
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_primitive_array.hpp
115114
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_utils.hpp

include/sparrow_ipc/chunk_memory_serializer.hpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
#pragma once
22

3+
#include <numeric>
4+
#include <optional>
5+
#include <ranges>
6+
#include <stdexcept>
7+
#include <vector>
8+
39
#include <sparrow/record_batch.hpp>
410

11+
#include "Message_generated.h"
12+
13+
#include "sparrow_ipc/any_output_stream.hpp"
514
#include "sparrow_ipc/chunk_memory_output_stream.hpp"
615
#include "sparrow_ipc/config/config.hpp"
716
#include "sparrow_ipc/memory_output_stream.hpp"
@@ -33,8 +42,9 @@ namespace sparrow_ipc
3342
* @brief Constructs a chunk serializer with a reference to a chunked memory output stream.
3443
*
3544
* @param stream Reference to a chunked memory output stream that will receive the serialized chunks
45+
* @param compression Optional: The compression type to use for record batch bodies.
3646
*/
37-
chunk_serializer(chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>& stream);
47+
chunk_serializer(chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt);
3848

3949
/**
4050
* @brief Writes a single record batch to the chunked stream.
@@ -120,6 +130,7 @@ namespace sparrow_ipc
120130
std::vector<sparrow::data_type> m_dtypes;
121131
chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>* m_pstream;
122132
bool m_ended{false};
133+
std::optional<org::apache::arrow::flatbuf::CompressionType> m_compression;
123134
};
124135

125136
// Implementation
@@ -133,7 +144,21 @@ namespace sparrow_ipc
133144
throw std::runtime_error("Cannot append record batches to a serializer that has been ended");
134145
}
135146

136-
m_pstream->reserve((m_schema_received ? 0 : 1) + m_pstream->size() + record_batches.size());
147+
const auto reserve_function = [&record_batches, this]()
148+
{
149+
return std::accumulate(
150+
record_batches.begin(),
151+
record_batches.end(),
152+
m_pstream->size(),
153+
[this](size_t acc, const sparrow::record_batch& rb)
154+
{
155+
return acc + calculate_record_batch_message_size(rb, m_compression);
156+
}
157+
)
158+
+ (m_schema_received ? 0 : calculate_schema_message_size(*record_batches.begin()));
159+
};
160+
161+
m_pstream->reserve(reserve_function);
137162

138163
if (!m_schema_received)
139164
{
@@ -148,10 +173,14 @@ namespace sparrow_ipc
148173

149174
for (const auto& rb : record_batches)
150175
{
176+
if (get_column_dtypes(rb) != m_dtypes)
177+
{
178+
throw std::invalid_argument("Record batch schema does not match serializer schema");
179+
}
151180
std::vector<uint8_t> buffer;
152181
memory_output_stream stream(buffer);
153182
any_output_stream astream(stream);
154-
serialize_record_batch(rb, astream);
183+
serialize_record_batch(rb, astream, m_compression);
155184
m_pstream->write(std::move(buffer));
156185
}
157186
}
@@ -169,4 +198,4 @@ namespace sparrow_ipc
169198
write(record_batches);
170199
return *this;
171200
}
172-
}
201+
}

include/sparrow_ipc/flatbuffer_utils.hpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,15 +213,20 @@ namespace sparrow_ipc
213213
* format that conforms to the Arrow IPC specification.
214214
*
215215
* @param record_batch The source record batch containing the data to be serialized
216-
*
216+
* @param compression Optional: The compression algorithm to be used for the message body
217+
* @param body_size Optional: An override for the total size of the message body
218+
* If not provided, the size is calculated from the uncompressed buffers
219+
* This is required when using compression
220+
* @param compressed_buffers Optional: A pointer to a vector of buffer metadata.
221+
* If provided, this metadata is used instead of generating it from the
222+
* uncompressed record batch. This is required when using compression.
217223
* @return A FlatBufferBuilder containing the complete serialized message ready for
218224
* transmission or storage. The builder is finished and ready to be accessed
219225
* via GetBufferPointer() and GetSize().
220226
*
221227
* @note The returned message uses Arrow IPC format version V5
222-
* @note Compression and variadic buffer counts are not currently implemented (set to 0)
223-
* @note The body size is automatically calculated based on the record batch contents
228+
* @note Variadic buffer counts is not currently implemented (set to 0)
224229
*/
225230
[[nodiscard]] flatbuffers::FlatBufferBuilder
226-
get_record_batch_message_builder(const sparrow::record_batch& record_batch);
227-
}
231+
get_record_batch_message_builder(const sparrow::record_batch& record_batch, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt, std::optional<std::int64_t> body_size = std::nullopt, const std::vector<org::apache::arrow::flatbuf::Buffer>* compressed_buffers = nullptr);
232+
}

include/sparrow_ipc/serialize.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@ namespace sparrow_ipc
2525
* @tparam R Container type that holds record batches (must support empty(), operator[], begin(), end())
2626
* @param record_batches Collection of record batches to serialize. All batches must have identical
2727
* schemas.
28-
* @param compression The compression type to use when serializing
29-
*
3028
* @param stream The output stream where the serialized data will be written.
29+
* @param compression The compression type to use when serializing.
3130
*
3231
* @throws std::invalid_argument If record batches have inconsistent schemas or if the collection
3332
* contains batches that cannot be serialized together.
@@ -70,13 +69,14 @@ namespace sparrow_ipc
7069
*
7170
* @param record_batch The sparrow record batch to serialize
7271
* @param stream The output stream where the serialized record batch will be written
72+
* @param compression The compression type to use when serializing.
7373
*
7474
* @note The output follows Arrow IPC message format with proper alignment and
7575
* includes both metadata and data portions of the record batch
7676
*/
7777

7878
SPARROW_IPC_API void
79-
serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream);
79+
serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression);
8080

8181
/**
8282
* @brief Serializes a schema message for a record batch into a byte buffer.

include/sparrow_ipc/serialize_utils.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "Message_generated.h"
99
#include "sparrow_ipc/any_output_stream.hpp"
1010
#include "sparrow_ipc/config/config.hpp"
11-
#include "sparrow_ipc/compression.hpp"
1211
#include "sparrow_ipc/utils.hpp"
1312

1413
namespace sparrow_ipc
@@ -40,8 +39,8 @@ namespace sparrow_ipc
4039
* consists of a metadata section followed by a body section containing the actual data.
4140
*
4241
* @param record_batch The sparrow record batch to be serialized
43-
* @param compression The compression type to use when serializing
4442
* @param stream The output stream where the serialized record batch will be written
43+
* @param compression The compression type to use when serializing
4544
*/
4645
SPARROW_IPC_API void
4746
serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression);
@@ -74,10 +73,11 @@ namespace sparrow_ipc
7473
* - Body data with 8-byte alignment between buffers
7574
*
7675
* @param record_batch The record batch to be measured
76+
* @param compression The compression type to use when serializing
7777
* @return The total size in bytes that the serialized record batch would occupy
7878
*/
7979
[[nodiscard]] SPARROW_IPC_API std::size_t
80-
calculate_record_batch_message_size(const sparrow::record_batch& record_batch);
80+
calculate_record_batch_message_size(const sparrow::record_batch& record_batch, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt);
8181

8282
/**
8383
* @brief Calculates the total serialized size for a collection of record batches.
@@ -87,12 +87,13 @@ namespace sparrow_ipc
8787
*
8888
* @tparam R Range type containing sparrow::record_batch objects
8989
* @param record_batches Collection of record batches to be measured
90+
* @param compression The compression type to use when serializing
9091
* @return The total size in bytes for the complete serialized output
9192
* @throws std::invalid_argument if record batches have inconsistent schemas
9293
*/
9394
template <std::ranges::input_range R>
9495
requires std::same_as<std::ranges::range_value_t<R>, sparrow::record_batch>
95-
[[nodiscard]] std::size_t calculate_total_serialized_size(const R& record_batches)
96+
[[nodiscard]] std::size_t calculate_total_serialized_size(const R& record_batches, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt)
9697
{
9798
if (record_batches.empty())
9899
{
@@ -111,7 +112,7 @@ namespace sparrow_ipc
111112
// Calculate record batch message sizes
112113
for (const auto& record_batch : record_batches)
113114
{
114-
total_size += calculate_record_batch_message_size(record_batch);
115+
total_size += calculate_record_batch_message_size(record_batch, compression);
115116
}
116117

117118
return total_size;

include/sparrow_ipc/serializer.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ namespace sparrow_ipc
4141
* The serializer stores a pointer to this stream for later use.
4242
*/
4343
template <writable_stream TStream>
44-
serializer(TStream& stream)
45-
: m_stream(stream)
44+
serializer(TStream& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt)
45+
: m_stream(stream), m_compression(compression)
4646
{
4747
}
4848

@@ -94,7 +94,7 @@ namespace sparrow_ipc
9494
m_stream.size(),
9595
[this](size_t acc, const sparrow::record_batch& rb)
9696
{
97-
return acc + calculate_record_batch_message_size(rb);
97+
return acc + calculate_record_batch_message_size(rb, m_compression);
9898
}
9999
)
100100
+ (m_schema_received ? 0 : calculate_schema_message_size(*record_batches.begin()));
@@ -115,7 +115,7 @@ namespace sparrow_ipc
115115
{
116116
throw std::invalid_argument("Record batch schema does not match serializer schema");
117117
}
118-
serialize_record_batch(rb, m_stream);
118+
serialize_record_batch(rb, m_stream, m_compression);
119119
}
120120
}
121121

@@ -206,6 +206,7 @@ namespace sparrow_ipc
206206
std::vector<sparrow::data_type> m_dtypes;
207207
any_output_stream m_stream;
208208
bool m_ended{false};
209+
std::optional<org::apache::arrow::flatbuf::CompressionType> m_compression;
209210
};
210211

211212
inline serializer& end_stream(serializer& serializer)

src/chunk_memory_serializer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
namespace sparrow_ipc
88
{
9-
chunk_serializer::chunk_serializer(chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>& stream)
10-
: m_pstream(&stream)
9+
chunk_serializer::chunk_serializer(chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression)
10+
: m_pstream(&stream), m_compression(compression)
1111
{
1212
}
1313

src/flatbuffer_utils.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -562,23 +562,28 @@ namespace sparrow_ipc
562562
return buffers;
563563
}
564564

565-
flatbuffers::FlatBufferBuilder get_record_batch_message_builder(const sparrow::record_batch& record_batch)
565+
flatbuffers::FlatBufferBuilder get_record_batch_message_builder(const sparrow::record_batch& record_batch, std::optional<org::apache::arrow::flatbuf::CompressionType> compression, std::optional<std::int64_t> body_size_override, const std::vector<org::apache::arrow::flatbuf::Buffer>* compressed_buffers)
566566
{
567567
const std::vector<org::apache::arrow::flatbuf::FieldNode> nodes = create_fieldnodes(record_batch);
568-
const std::vector<org::apache::arrow::flatbuf::Buffer> buffers = get_buffers(record_batch);
568+
const std::vector<org::apache::arrow::flatbuf::Buffer>& buffers = compressed_buffers ? *compressed_buffers : get_buffers(record_batch);
569569
flatbuffers::FlatBufferBuilder record_batch_builder;
570570
auto nodes_offset = record_batch_builder.CreateVectorOfStructs(nodes);
571571
auto buffers_offset = record_batch_builder.CreateVectorOfStructs(buffers);
572+
flatbuffers::Offset<org::apache::arrow::flatbuf::BodyCompression> compression_offset = 0;
573+
if (compression)
574+
{
575+
compression_offset = org::apache::arrow::flatbuf::CreateBodyCompression(record_batch_builder, compression.value(), org::apache::arrow::flatbuf::BodyCompressionMethod::BUFFER);
576+
}
572577
const auto record_batch_offset = org::apache::arrow::flatbuf::CreateRecordBatch(
573578
record_batch_builder,
574579
static_cast<int64_t>(record_batch.nb_rows()),
575580
nodes_offset,
576581
buffers_offset,
577-
0, // TODO: Compression
582+
compression_offset,
578583
0 // TODO :variadic buffer Counts
579584
);
580585

581-
const int64_t body_size = calculate_body_size(record_batch);
586+
const int64_t body_size = body_size_override.value_or(calculate_body_size(record_batch));
582587
const auto record_batch_message_offset = org::apache::arrow::flatbuf::CreateMessage(
583588
record_batch_builder,
584589
org::apache::arrow::flatbuf::MetadataVersion::V5,

src/serialize.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
#include "sparrow_ipc/serialize.hpp"
1+
#include <optional>
22

3+
#include "sparrow_ipc/serialize.hpp"
34
#include "sparrow_ipc/flatbuffer_utils.hpp"
45

56
namespace sparrow_ipc
67
{
78
void common_serialize(
8-
const sparrow::record_batch& record_batch,
99
const flatbuffers::FlatBufferBuilder& builder,
1010
any_output_stream& stream
1111
)
@@ -20,12 +20,23 @@ namespace sparrow_ipc
2020

2121
void serialize_schema_message(const sparrow::record_batch& record_batch, any_output_stream& stream)
2222
{
23-
common_serialize(record_batch, get_schema_message_builder(record_batch), stream);
23+
common_serialize(get_schema_message_builder(record_batch), stream);
2424
}
2525

26-
void serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream)
26+
void serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression)
2727
{
28-
common_serialize(record_batch, get_record_batch_message_builder(record_batch), stream);
29-
generate_body(record_batch, stream);
28+
if (compression.has_value())
29+
{
30+
// TODO Handle this inside get_record_batch_message_builder
31+
auto [compressed_body, compressed_buffers] = generate_compressed_body_and_buffers(record_batch, compression.value());
32+
common_serialize(get_record_batch_message_builder(record_batch, compression, compressed_body.size(), &compressed_buffers), stream);
33+
// TODO Use something equivalent to generate_body (stream wise, handling children etc)
34+
stream.write(std::span(compressed_body.data(), compressed_body.size()));
35+
}
36+
else
37+
{
38+
common_serialize(get_record_batch_message_builder(record_batch, compression), stream);
39+
generate_body(record_batch, stream);
40+
}
3041
}
31-
}
42+
}

0 commit comments

Comments
 (0)