Skip to content

Commit 6de389b

Browse files
authored
Support variable_size_binary_view_array (#74)
* Support variable_size_binary_view_array
* Skip serialization of last buffer in views
* Fix rebase
1 parent 587c242 commit 6de389b

9 files changed

+268
-33
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ set(SPARROW_IPC_HEADERS
133133
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_time_related_arrays.hpp
134134
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_utils.hpp
135135
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_array.hpp
136+
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_view_array.hpp
136137
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize.hpp
137138
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserializer.hpp
138139
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/encapsulated_message.hpp
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#pragma once
2+
3+
#include <span>
4+
#include <unordered_set>
5+
6+
#include <sparrow/arrow_interface/arrow_array_schema_proxy.hpp>
7+
#include <sparrow/variable_size_binary_view_array.hpp>
8+
9+
#include "Message_generated.h"
10+
#include "sparrow_ipc/arrow_interface/arrow_array.hpp"
11+
#include "sparrow_ipc/arrow_interface/arrow_schema.hpp"
12+
#include "sparrow_ipc/deserialize_utils.hpp"
13+
14+
namespace sparrow_ipc
15+
{
16+
template <typename T>
17+
[[nodiscard]] T deserialize_variable_size_binary_view_array(
18+
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
19+
std::span<const uint8_t> body,
20+
std::string_view name,
21+
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
22+
bool nullable,
23+
size_t& buffer_index,
24+
const int64_t data_buffers_size
25+
)
26+
{
27+
// TODO Use the commented line below instead of the following snippet when this is handled/added in sparrow
28+
// const std::string_view format = data_type_to_format(sparrow::detail::get_data_type_from_array<T>::get());
29+
std::string format;
30+
if (sparrow::detail::get_data_type_from_array<T>::get() == sparrow::data_type::STRING_VIEW)
31+
{
32+
format = "vu";
33+
}
34+
else if (sparrow::detail::get_data_type_from_array<T>::get() == sparrow::data_type::BINARY_VIEW)
35+
{
36+
format = "vz";
37+
}
38+
else
39+
{
40+
throw std::runtime_error("Unsupported view type");
41+
}
42+
43+
// Set up flags based on nullable
44+
std::optional<std::unordered_set<sparrow::ArrowFlag>> flags;
45+
if (nullable)
46+
{
47+
flags = std::unordered_set<sparrow::ArrowFlag>{sparrow::ArrowFlag::NULLABLE};
48+
}
49+
50+
ArrowSchema schema = make_non_owning_arrow_schema(
51+
format,
52+
name.data(),
53+
metadata,
54+
flags,
55+
0,
56+
nullptr,
57+
nullptr
58+
);
59+
60+
const auto compression = record_batch.compression();
61+
std::vector<arrow_array_private_data::optionally_owned_buffer> buffers;
62+
63+
auto validity_buffer_span = utils::get_buffer(record_batch, body, buffer_index);
64+
auto views_buffer_span = utils::get_buffer(record_batch, body, buffer_index);
65+
66+
if (compression)
67+
{
68+
buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
69+
buffers.push_back(utils::get_decompressed_buffer(views_buffer_span, compression));
70+
}
71+
else
72+
{
73+
buffers.push_back(validity_buffer_span);
74+
buffers.push_back(views_buffer_span);
75+
}
76+
77+
// If no data buffers are present, we still need to push an empty data buffer to have things valid in sparrow
78+
if (data_buffers_size == 0)
79+
{
80+
buffers.push_back(arrow_array_private_data::optionally_owned_buffer(std::span<const uint8_t>{}));
81+
}
82+
83+
for (auto i = 0; i < data_buffers_size; ++i)
84+
{
85+
auto data_buffer_span =
86+
utils::get_buffer(record_batch, body, buffer_index);
87+
88+
if (compression)
89+
{
90+
buffers.push_back(
91+
utils::get_decompressed_buffer(data_buffer_span, compression)
92+
);
93+
}
94+
else
95+
{
96+
buffers.push_back(data_buffer_span);
97+
}
98+
}
99+
100+
const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length());
101+
102+
ArrowArray array = make_arrow_array<arrow_array_private_data>(
103+
record_batch.length(),
104+
null_count,
105+
0, // n_children
106+
0, // n_dictionaries
107+
nullptr, // children
108+
nullptr, // dictionary
109+
std::move(buffers)
110+
);
111+
112+
sparrow::arrow_proxy ap{std::move(array), std::move(schema)};
113+
return T{std::move(ap)};
114+
}
115+
}

include/sparrow_ipc/flatbuffer_utils.hpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
#pragma once
2+
3+
#include <ranges>
4+
25
#include <flatbuffers/flatbuffers.h>
36
#include <Message_generated.h>
47

@@ -173,6 +176,8 @@ namespace sparrow_ipc
173176

174177
namespace details
175178
{
179+
std::size_t get_nb_buffers_to_process(const std::string_view& format, const std::size_t orig_buffers_size);
180+
176181
template <typename Func>
177182
void fill_buffers_impl(
178183
const sparrow::arrow_proxy& arrow_proxy,
@@ -182,12 +187,15 @@ namespace sparrow_ipc
182187
)
183188
{
184189
const auto& buffers = arrow_proxy.buffers();
185-
for (const auto& buffer : buffers)
190+
auto nb_buffers = get_nb_buffers_to_process(arrow_proxy.schema().format, buffers.size());
191+
std::ranges::for_each(buffers | std::views::take(nb_buffers),
192+
[&](const auto& buffer)
186193
{
187194
int64_t size = get_buffer_size(buffer);
188195
flatbuf_buffers.emplace_back(offset, size);
189196
offset += utils::align_to_8(size);
190-
}
197+
});
198+
191199
for (const auto& child : arrow_proxy.children())
192200
{
193201
fill_buffers_impl(child, flatbuf_buffers, offset, get_buffer_size);

src/array_deserializer.cpp

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ namespace sparrow_ipc
3939
m_deserializer_map[org::apache::arrow::flatbuf::Type::Time] = &deserialize_time;
4040
m_deserializer_map[org::apache::arrow::flatbuf::Type::Null] = &deserialize_null;
4141
m_deserializer_map[org::apache::arrow::flatbuf::Type::Decimal] = &deserialize_decimal;
42+
m_deserializer_map[org::apache::arrow::flatbuf::Type::BinaryView] = &deserialize_variable_size_binary_view<sparrow::binary_view_array>;
43+
m_deserializer_map[org::apache::arrow::flatbuf::Type::Utf8View] = &deserialize_variable_size_binary_view<sparrow::string_view_array>;
4244
}
4345

4446
sparrow::array array_deserializer::deserialize(const org::apache::arrow::flatbuf::RecordBatch& record_batch,
@@ -47,6 +49,7 @@ namespace sparrow_ipc
4749
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
4850
bool nullable,
4951
size_t& buffer_index,
52+
size_t& variadic_counts_idx,
5053
const org::apache::arrow::flatbuf::Field& field) const
5154
{
5255
auto it = m_deserializer_map.find(field.type_type());
@@ -57,7 +60,7 @@ namespace sparrow_ipc
5760
+ " for field '" + name + "'"
5861
);
5962
}
60-
return it->second(record_batch, body, name, metadata, nullable, buffer_index, field);
63+
return it->second(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
6164
}
6265

6366
sparrow::array array_deserializer::deserialize_int(const org::apache::arrow::flatbuf::RecordBatch& record_batch,
@@ -66,6 +69,7 @@ namespace sparrow_ipc
6669
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
6770
bool nullable,
6871
size_t& buffer_index,
72+
size_t& variadic_counts_idx,
6973
const org::apache::arrow::flatbuf::Field& field)
7074
{
7175
const auto* int_type = field.type_as_Int();
@@ -76,21 +80,21 @@ namespace sparrow_ipc
7680
{
7781
switch (bit_width)
7882
{
79-
case BIT_WIDTH_8: return deserialize_primitive<int8_t>(record_batch, body, name, metadata, nullable, buffer_index, field);
80-
case BIT_WIDTH_16: return deserialize_primitive<int16_t>(record_batch, body, name, metadata, nullable, buffer_index, field);
81-
case BIT_WIDTH_32: return deserialize_primitive<int32_t>(record_batch, body, name, metadata, nullable, buffer_index, field);
82-
case BIT_WIDTH_64: return deserialize_primitive<int64_t>(record_batch, body, name, metadata, nullable, buffer_index, field);
83+
case BIT_WIDTH_8: return deserialize_primitive<int8_t>(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
84+
case BIT_WIDTH_16: return deserialize_primitive<int16_t>(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
85+
case BIT_WIDTH_32: return deserialize_primitive<int32_t>(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
86+
case BIT_WIDTH_64: return deserialize_primitive<int64_t>(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
8387
default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width));
8488
}
8589
}
8690
else
8791
{
8892
switch (bit_width)
8993
{
90-
case BIT_WIDTH_8: return deserialize_primitive<uint8_t>(record_batch, body, name, metadata, nullable, buffer_index, field);
91-
case BIT_WIDTH_16: return deserialize_primitive<uint16_t>(record_batch, body, name, metadata, nullable, buffer_index, field);
92-
case BIT_WIDTH_32: return deserialize_primitive<uint32_t>(record_batch, body, name, metadata, nullable, buffer_index, field);
93-
case BIT_WIDTH_64: return deserialize_primitive<uint64_t>(record_batch, body, name, metadata, nullable, buffer_index, field);
94+
case BIT_WIDTH_8: return deserialize_primitive<uint8_t>(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
95+
case BIT_WIDTH_16: return deserialize_primitive<uint16_t>(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
96+
case BIT_WIDTH_32: return deserialize_primitive<uint32_t>(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
97+
case BIT_WIDTH_64: return deserialize_primitive<uint64_t>(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
9498
default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width));
9599
}
96100
}
@@ -102,15 +106,16 @@ namespace sparrow_ipc
102106
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
103107
bool nullable,
104108
size_t& buffer_index,
109+
size_t& variadic_counts_idx,
105110
const org::apache::arrow::flatbuf::Field& field)
106111
{
107112
const auto* float_type = field.type_as_FloatingPoint();
108113
const auto precision = float_type->precision();
109114
switch (precision)
110115
{
111-
case org::apache::arrow::flatbuf::Precision::HALF: return deserialize_primitive<sparrow::float16_t>(record_batch, body, name, metadata, nullable, buffer_index, field);
112-
case org::apache::arrow::flatbuf::Precision::SINGLE: return deserialize_primitive<float>(record_batch, body, name, metadata, nullable, buffer_index, field);
113-
case org::apache::arrow::flatbuf::Precision::DOUBLE: return deserialize_primitive<double>(record_batch, body, name, metadata, nullable, buffer_index, field);
116+
case org::apache::arrow::flatbuf::Precision::HALF: return deserialize_primitive<sparrow::float16_t>(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
117+
case org::apache::arrow::flatbuf::Precision::SINGLE: return deserialize_primitive<float>(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
118+
case org::apache::arrow::flatbuf::Precision::DOUBLE: return deserialize_primitive<double>(record_batch, body, name, metadata, nullable, buffer_index, variadic_counts_idx, field);
114119
default: throw std::runtime_error("Unsupported floating point precision: " + std::to_string(static_cast<int>(precision)));
115120
}
116121
}
@@ -121,6 +126,7 @@ namespace sparrow_ipc
121126
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
122127
bool nullable,
123128
size_t& buffer_index,
129+
size_t&,
124130
const org::apache::arrow::flatbuf::Field& field)
125131
{
126132
const auto* fixed_size_binary_field = field.type_as_FixedSizeBinary();
@@ -136,6 +142,7 @@ namespace sparrow_ipc
136142
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
137143
bool nullable,
138144
size_t& buffer_index,
145+
size_t&,
139146
const org::apache::arrow::flatbuf::Field& field)
140147
{
141148
const auto* decimal_field = field.type_as_Decimal();
@@ -171,6 +178,7 @@ namespace sparrow_ipc
171178
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
172179
bool nullable,
173180
size_t& buffer_index,
181+
size_t&,
174182
const org::apache::arrow::flatbuf::Field&)
175183
{
176184
return sparrow::array(deserialize_null_array(
@@ -184,6 +192,7 @@ namespace sparrow_ipc
184192
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
185193
bool nullable,
186194
size_t& buffer_index,
195+
size_t&,
187196
const org::apache::arrow::flatbuf::Field& field)
188197
{
189198
const auto date_type = field.type_as_Date();
@@ -202,6 +211,7 @@ namespace sparrow_ipc
202211
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
203212
bool nullable,
204213
size_t& buffer_index,
214+
size_t&,
205215
const org::apache::arrow::flatbuf::Field& field)
206216
{
207217
const auto* interval_type = field.type_as_Interval();
@@ -221,6 +231,7 @@ namespace sparrow_ipc
221231
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
222232
bool nullable,
223233
size_t& buffer_index,
234+
size_t&,
224235
const org::apache::arrow::flatbuf::Field& field)
225236
{
226237
const auto* duration_type = field.type_as_Duration();
@@ -241,6 +252,7 @@ namespace sparrow_ipc
241252
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
242253
bool nullable,
243254
size_t& buffer_index,
255+
size_t&,
244256
const org::apache::arrow::flatbuf::Field& field)
245257
{
246258
const auto time_type = field.type_as_Time();
@@ -261,6 +273,7 @@ namespace sparrow_ipc
261273
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
262274
bool nullable,
263275
size_t& buffer_index,
276+
size_t&,
264277
const org::apache::arrow::flatbuf::Field& field)
265278
{
266279
const auto timestamp_type = field.type_as_Timestamp();

0 commit comments

Comments
 (0)