11#pragma once
22
3+ #include < numeric>
4+ #include < optional>
5+ #include < ranges>
6+ #include < stdexcept>
7+ #include < vector>
8+
39#include < sparrow/record_batch.hpp>
410
11+ #include " Message_generated.h"
12+
13+ #include " sparrow_ipc/any_output_stream.hpp"
514#include " sparrow_ipc/chunk_memory_output_stream.hpp"
615#include " sparrow_ipc/config/config.hpp"
716#include " sparrow_ipc/memory_output_stream.hpp"
@@ -33,8 +42,9 @@ namespace sparrow_ipc
3342 * @brief Constructs a chunk serializer with a reference to a chunked memory output stream.
3443 *
3544 * @param stream Reference to a chunked memory output stream that will receive the serialized chunks
45+ * @param compression Optional: The compression type to use for record batch bodies.
3646 */
37- chunk_serializer (chunked_memory_output_stream<std::vector<std::vector<uint8_t >>>& stream);
47+ chunk_serializer (chunked_memory_output_stream<std::vector<std::vector<uint8_t >>>& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std:: nullopt );
3848
3949 /* *
4050 * @brief Writes a single record batch to the chunked stream.
@@ -120,6 +130,7 @@ namespace sparrow_ipc
120130 std::vector<sparrow::data_type> m_dtypes;
121131 chunked_memory_output_stream<std::vector<std::vector<uint8_t >>>* m_pstream;
122132 bool m_ended{false };
133+ std::optional<org::apache::arrow::flatbuf::CompressionType> m_compression;
123134 };
124135
125136 // Implementation
@@ -133,7 +144,21 @@ namespace sparrow_ipc
133144 throw std::runtime_error (" Cannot append record batches to a serializer that has been ended" );
134145 }
135146
136- m_pstream->reserve ((m_schema_received ? 0 : 1 ) + m_pstream->size () + record_batches.size ());
147+ const auto reserve_function = [&record_batches, this ]()
148+ {
149+ return std::accumulate (
150+ record_batches.begin (),
151+ record_batches.end (),
152+ m_pstream->size (),
153+ [this ](size_t acc, const sparrow::record_batch& rb)
154+ {
155+ return acc + calculate_record_batch_message_size (rb, m_compression);
156+ }
157+ )
158+ + (m_schema_received ? 0 : calculate_schema_message_size (*record_batches.begin ()));
159+ };
160+
161+ m_pstream->reserve (reserve_function);
137162
138163 if (!m_schema_received)
139164 {
@@ -148,10 +173,14 @@ namespace sparrow_ipc
148173
149174 for (const auto & rb : record_batches)
150175 {
176+ if (get_column_dtypes (rb) != m_dtypes)
177+ {
178+ throw std::invalid_argument (" Record batch schema does not match serializer schema" );
179+ }
151180 std::vector<uint8_t > buffer;
152181 memory_output_stream stream (buffer);
153182 any_output_stream astream (stream);
154- serialize_record_batch (rb, astream);
183+ serialize_record_batch (rb, astream, m_compression );
155184 m_pstream->write (std::move (buffer));
156185 }
157186 }
@@ -169,4 +198,4 @@ namespace sparrow_ipc
169198 write (record_batches);
170199 return *this ;
171200 }
172- }
201+ }
0 commit comments