Skip to content

Commit c5cfe1a

Browse files
author
longshan.lu
committed
feat(arrow): enhance ArrowAssoc and Postgres transport for Float32Array support
- Implemented ArrowAssoc for Vec<Option<f32>> with appropriate builders and append methods.
- Updated ArrowTypeSystem to include the Float32Array mapping.
- Enhanced the Postgres transport mappings to support conversion of Vector, HalfVector, and SparseVector to Float32Array.
- Added type conversion implementations for the IpInet and Bit types in PostgresArrowTransport.
1 parent 610ce21 commit c5cfe1a

File tree

3 files changed

+160
-36
lines changed

3 files changed

+160
-36
lines changed

connectorx/src/destinations/arrowstream/arrow_assoc.rs

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ use crate::constants::{DEFAULT_ARROW_DECIMAL, DEFAULT_ARROW_DECIMAL_SCALE, SECON
33
use crate::utils::decimal_to_i128;
44
use arrow::array::{
55
ArrayBuilder, BooleanBuilder, Date32Builder, Date64Builder, Decimal128Builder, Float32Builder,
6-
Float64Builder, Int32Builder, Int64Builder, LargeBinaryBuilder, StringBuilder,
7-
Time64NanosecondBuilder, TimestampNanosecondBuilder, UInt32Builder, UInt64Builder,
6+
Float64Builder, Int32Builder, Int64Builder, LargeBinaryBuilder, LargeListBuilder,
7+
StringBuilder, Time64NanosecondBuilder, TimestampNanosecondBuilder, UInt32Builder,
8+
UInt64Builder,
89
};
910
use arrow::datatypes::Field;
1011
use arrow::datatypes::{DataType as ArrowDataType, TimeUnit};
@@ -384,3 +385,51 @@ impl ArrowAssoc for Vec<u8> {
384385
Field::new(header, ArrowDataType::LargeBinary, false)
385386
}
386387
}
388+
389+
macro_rules! impl_arrow_array_assoc {
390+
($T:ty, $AT:expr, $B:ident) => {
391+
impl ArrowAssoc for $T {
392+
type Builder = LargeListBuilder<$B>;
393+
394+
fn builder(nrows: usize) -> Self::Builder {
395+
LargeListBuilder::with_capacity($B::new(), nrows)
396+
}
397+
398+
#[throws(ArrowDestinationError)]
399+
fn append(builder: &mut Self::Builder, value: Self) {
400+
builder.append_value(value);
401+
}
402+
403+
fn field(header: &str) -> Field {
404+
Field::new(
405+
header,
406+
ArrowDataType::LargeList(std::sync::Arc::new(Field::new_list_field($AT, true))),
407+
false,
408+
)
409+
}
410+
}
411+
412+
impl ArrowAssoc for Option<$T> {
413+
type Builder = LargeListBuilder<$B>;
414+
415+
fn builder(nrows: usize) -> Self::Builder {
416+
LargeListBuilder::with_capacity($B::new(), nrows)
417+
}
418+
419+
#[throws(ArrowDestinationError)]
420+
fn append(builder: &mut Self::Builder, value: Self) {
421+
builder.append_option(value);
422+
}
423+
424+
fn field(header: &str) -> Field {
425+
Field::new(
426+
header,
427+
ArrowDataType::LargeList(std::sync::Arc::new(Field::new_list_field($AT, true))),
428+
true,
429+
)
430+
}
431+
}
432+
};
433+
}
434+
435+
impl_arrow_array_assoc!(Vec<Option<f32>>, ArrowDataType::Float32, Float32Builder);

connectorx/src/destinations/arrowstream/typesystem.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,26 @@ pub enum ArrowTypeSystem {
1818
Date64(bool),
1919
Time64(bool),
2020
DateTimeTz(bool),
21+
Float32Array(bool),
2122
}
2223

2324
impl_typesystem! {
2425
system = ArrowTypeSystem,
2526
mappings = {
26-
{ Int32 => i32 }
27-
{ Int64 => i64 }
28-
{ UInt32 => u32 }
29-
{ UInt64 => u64 }
30-
{ Float64 => f64 }
31-
{ Float32 => f32 }
32-
{ Decimal => Decimal }
33-
{ Boolean => bool }
34-
{ LargeUtf8 => String }
35-
{ LargeBinary => Vec<u8> }
36-
{ Date32 => NaiveDate }
37-
{ Date64 => NaiveDateTime }
38-
{ Time64 => NaiveTime }
39-
{ DateTimeTz => DateTime<Utc> }
27+
{ Int32 => i32 }
28+
{ Int64 => i64 }
29+
{ UInt32 => u32 }
30+
{ UInt64 => u64 }
31+
{ Float64 => f64 }
32+
{ Float32 => f32 }
33+
{ Decimal => Decimal }
34+
{ Boolean => bool }
35+
{ LargeUtf8 => String }
36+
{ LargeBinary => Vec<u8> }
37+
{ Date32 => NaiveDate }
38+
{ Date64 => NaiveDateTime }
39+
{ Time64 => NaiveTime }
40+
{ DateTimeTz => DateTime<Utc> }
41+
{ Float32Array => Vec<Option<f32>> }
4042
}
4143
}

connectorx/src/transports/postgres_arrowstream.rs

Lines changed: 93 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use crate::sources::postgres::{
99
};
1010
use crate::typesystem::TypeConversion;
1111
use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
12+
use cidr_02::IpInet;
13+
use pgvector::{Bit, HalfVector, SparseVector, Vector};
1214
use postgres::NoTls;
1315
use postgres_openssl::MakeTlsConnector;
1416
use rust_decimal::Decimal;
@@ -40,26 +42,31 @@ macro_rules! impl_postgres_transport {
4042
systems = PostgresTypeSystem => ArrowTypeSystem,
4143
route = PostgresSource<$proto, $tls> => ArrowDestination,
4244
mappings = {
43-
{ Float4[f32] => Float64[f64] | conversion auto }
44-
{ Float8[f64] => Float64[f64] | conversion auto }
45-
{ Numeric[Decimal] => Decimal[Decimal] | conversion auto }
46-
{ Int2[i16] => Int64[i64] | conversion auto }
47-
{ Int4[i32] => Int64[i64] | conversion auto }
48-
{ Int8[i64] => Int64[i64] | conversion auto }
49-
{ Bool[bool] => Boolean[bool] | conversion auto }
50-
{ Text[&'r str] => LargeUtf8[String] | conversion owned }
51-
{ BpChar[&'r str] => LargeUtf8[String] | conversion none }
52-
{ VarChar[&'r str] => LargeUtf8[String] | conversion none }
53-
{ Name[&'r str] => LargeUtf8[String] | conversion none }
54-
{ Timestamp[NaiveDateTime] => Date64[NaiveDateTime] | conversion auto }
55-
{ Date[NaiveDate] => Date32[NaiveDate] | conversion auto }
56-
{ Time[NaiveTime] => Time64[NaiveTime] | conversion auto }
57-
{ TimestampTz[DateTime<Utc>] => DateTimeTz[DateTime<Utc>] | conversion auto }
58-
{ UUID[Uuid] => LargeUtf8[String] | conversion option }
59-
{ Char[&'r str] => LargeUtf8[String] | conversion none }
60-
{ ByteA[Vec<u8>] => LargeBinary[Vec<u8>] | conversion auto }
61-
{ JSON[Value] => LargeUtf8[String] | conversion option }
62-
{ JSONB[Value] => LargeUtf8[String] | conversion none }
45+
{ Float4[f32] => Float64[f64] | conversion auto }
46+
{ Float8[f64] => Float64[f64] | conversion auto }
47+
{ Numeric[Decimal] => Decimal[Decimal] | conversion auto }
48+
{ Int2[i16] => Int64[i64] | conversion auto }
49+
{ Int4[i32] => Int64[i64] | conversion auto }
50+
{ Int8[i64] => Int64[i64] | conversion auto }
51+
{ Bool[bool] => Boolean[bool] | conversion auto }
52+
{ Text[&'r str] => LargeUtf8[String] | conversion owned }
53+
{ BpChar[&'r str] => LargeUtf8[String] | conversion none }
54+
{ VarChar[&'r str] => LargeUtf8[String] | conversion none }
55+
{ Name[&'r str] => LargeUtf8[String] | conversion none }
56+
{ Timestamp[NaiveDateTime] => Date64[NaiveDateTime] | conversion auto }
57+
{ Date[NaiveDate] => Date32[NaiveDate] | conversion auto }
58+
{ Time[NaiveTime] => Time64[NaiveTime] | conversion auto }
59+
{ TimestampTz[DateTime<Utc>] => DateTimeTz[DateTime<Utc>] | conversion auto }
60+
{ UUID[Uuid] => LargeUtf8[String] | conversion option }
61+
{ Char[&'r str] => LargeUtf8[String] | conversion none }
62+
{ ByteA[Vec<u8>] => LargeBinary[Vec<u8>] | conversion auto }
63+
{ JSON[Value] => LargeUtf8[String] | conversion option }
64+
{ JSONB[Value] => LargeUtf8[String] | conversion none }
65+
{ Inet[IpInet] => LargeUtf8[String] | conversion none }
66+
{ Vector[Vector] => Float32Array[Vec<Option<f32>>] | conversion none }
67+
{ HalfVec[HalfVector] => Float32Array[Vec<Option<f32>>] | conversion none }
68+
{ Bit[Bit] => LargeBinary[Vec<u8>] | conversion none }
69+
{ SparseVec[SparseVector] => Float32Array[Vec<Option<f32>>] | conversion none }
6370
}
6471
);
6572
}
@@ -74,6 +81,18 @@ impl_postgres_transport!(CursorProtocol, MakeTlsConnector);
7481
impl_postgres_transport!(SimpleProtocol, NoTls);
7582
impl_postgres_transport!(SimpleProtocol, MakeTlsConnector);
7683

84+
/// Converts an `IpInet` into its `to_string()` text form.
impl<P, C> TypeConversion<IpInet, String> for PostgresArrowTransport<P, C> {
    fn convert(val: IpInet) -> String {
        ToString::to_string(&val)
    }
}
89+
90+
/// Converts an optional `IpInet` into its optional `to_string()` text form;
/// `None` stays `None`.
impl<P, C> TypeConversion<Option<IpInet>, Option<String>> for PostgresArrowTransport<P, C> {
    fn convert(val: Option<IpInet>) -> Option<String> {
        val.as_ref().map(ToString::to_string)
    }
}
95+
7796
impl<P, C> TypeConversion<Uuid, String> for PostgresArrowTransport<P, C> {
7897
fn convert(val: Uuid) -> String {
7998
val.to_string()
@@ -85,3 +104,57 @@ impl<P, C> TypeConversion<Value, String> for PostgresArrowTransport<P, C> {
85104
val.to_string()
86105
}
87106
}
107+
108+
/// Converts a pgvector `Vector` into a list of `f32` values, wrapping every
/// element in `Some` so it fits the `Vec<Option<f32>>` destination type.
impl<P, C> TypeConversion<Vector, Vec<Option<f32>>> for PostgresArrowTransport<P, C> {
    fn convert(val: Vector) -> Vec<Option<f32>> {
        let components: Vec<f32> = val.to_vec();
        components.into_iter().map(Some).collect()
    }
}
113+
114+
impl<P, C> TypeConversion<Option<Vector>, Option<Vec<Option<f32>>>>
115+
for PostgresArrowTransport<P, C>
116+
{
117+
fn convert(val: Option<Vector>) -> Option<Vec<Option<f32>>> {
118+
val.map(|val| val.to_vec().into_iter().map(Some).collect())
119+
}
120+
}
121+
122+
/// Converts a pgvector `HalfVector` into a list of `f32` values: each element
/// is converted with `to_f32()` and wrapped in `Some`.
impl<P, C> TypeConversion<HalfVector, Vec<Option<f32>>> for PostgresArrowTransport<P, C> {
    fn convert(val: HalfVector) -> Vec<Option<f32>> {
        let halves = val.to_vec();
        halves
            .into_iter()
            .map(|half| Some(half.to_f32()))
            .collect()
    }
}
127+
128+
impl<P, C> TypeConversion<Option<HalfVector>, Option<Vec<Option<f32>>>>
129+
for PostgresArrowTransport<P, C>
130+
{
131+
fn convert(val: Option<HalfVector>) -> Option<Vec<Option<f32>>> {
132+
val.map(|val| val.to_vec().into_iter().map(|v| Some(v.to_f32())).collect())
133+
}
134+
}
135+
136+
/// Converts a pgvector `Bit` into an owned copy of the bytes returned by
/// `Bit::as_bytes`.
// NOTE(review): only the bytes are forwarded; whether the exact bit length
// must be preserved for consumers is not visible here - confirm.
impl<P, C> TypeConversion<Bit, Vec<u8>> for PostgresArrowTransport<P, C> {
    fn convert(val: Bit) -> Vec<u8> {
        val.as_bytes().to_vec()
    }
}
141+
142+
/// Converts an optional pgvector `Bit` into an optional owned copy of the
/// bytes returned by `Bit::as_bytes`; `None` stays `None`.
impl<P, C> TypeConversion<Option<Bit>, Option<Vec<u8>>> for PostgresArrowTransport<P, C> {
    fn convert(val: Option<Bit>) -> Option<Vec<u8>> {
        val.as_ref().map(|bits| bits.as_bytes().to_vec())
    }
}
147+
148+
/// Converts a pgvector `SparseVector` into a list of `f32` values, wrapping
/// every element produced by `SparseVector::to_vec` in `Some`.
// NOTE(review): `to_vec` presumably expands the sparse vector to its dense
// form - confirm against the pgvector crate documentation.
impl<P, C> TypeConversion<SparseVector, Vec<Option<f32>>> for PostgresArrowTransport<P, C> {
    fn convert(val: SparseVector) -> Vec<Option<f32>> {
        let dense = val.to_vec();
        dense.into_iter().map(Some).collect()
    }
}
153+
154+
impl<P, C> TypeConversion<Option<SparseVector>, Option<Vec<Option<f32>>>>
155+
for PostgresArrowTransport<P, C>
156+
{
157+
fn convert(val: Option<SparseVector>) -> Option<Vec<Option<f32>>> {
158+
val.map(|val| val.to_vec().into_iter().map(Some).collect())
159+
}
160+
}

0 commit comments

Comments
 (0)