Skip to content

Commit 610ce21

Browse files
author
longshan.lu
committed
feat(postgres): integrate pgvector support in Postgres transport
- Added support for pgvector types (Vector, HalfVector, Bit, SparseVector) in the Postgres transport. - Updated Cargo.toml and Cargo.lock to include the pgvector dependency. - Enhanced type conversion for pgvector types in PostgresArrowTransport. - Created SQL scripts to define and populate a new table for vector types. - Added tests to verify the handling of pgvector types in the Postgres integration.
1 parent c474f6a commit 610ce21

File tree

7 files changed

+297
-268
lines changed

7 files changed

+297
-268
lines changed

Cargo.lock

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

connectorx/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ datafusion = {version = "46", optional = true}
6060
prusto = {version = "0.5", optional = true}
6161
serde = {version = "1", optional = true}
6262
cidr-02 = { version = "0.2", package = "cidr", optional = true }
63+
pgvector = { version = "0.4", features = [ "postgres", "halfvec"], optional = true }
6364

6465
[lib]
6566
crate-type = ["cdylib", "rlib"]
@@ -99,6 +100,7 @@ src_postgres = [
99100
"openssl",
100101
"postgres-openssl",
101102
"cidr-02",
103+
"pgvector",
102104
]
103105
src_sqlite = ["rusqlite", "r2d2_sqlite", "fallible-streaming-iterator", "r2d2", "urlencoding"]
104106
src_trino = ["prusto", "uuid", "urlencoding", "rust_decimal", "tokio", "num-traits", "serde"]

connectorx/src/sources/postgres/mod.rs

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod typesystem;
77
pub use self::errors::PostgresSourceError;
88
use cidr_02::IpInet;
99
pub use connection::rewrite_tls_args;
10+
use pgvector::{Bit, HalfVector, SparseVector, Vector};
1011
pub use typesystem::{PostgresTypePairs, PostgresTypeSystem};
1112

1213
use crate::constants::DB_BUFFER_SIZE;
@@ -52,6 +53,47 @@ pub enum SimpleProtocol {}
5253
type PgManager<C> = PostgresConnectionManager<C>;
5354
type PgConn<C> = PooledConnection<PgManager<C>>;
5455

56+
macro_rules! impl_produce_unimplemented {
57+
($(($protocol: ty, $t: ty, $msg: expr),)+) => {
58+
$(
59+
impl<'r> Produce<'r, $t> for $protocol {
60+
type Error = PostgresSourceError;
61+
62+
#[throws(PostgresSourceError)]
63+
fn produce(&'r mut self) -> $t {
64+
unimplemented!($msg);
65+
}
66+
}
67+
68+
impl<'r> Produce<'r, Option<$t>> for $protocol {
69+
type Error = PostgresSourceError;
70+
71+
#[throws(PostgresSourceError)]
72+
fn produce(&'r mut self) -> Option<$t> {
73+
unimplemented!($msg);
74+
}
75+
}
76+
)+
77+
};
78+
}
79+
80+
impl_produce_unimplemented!(
81+
(PostgresCSVSourceParser<'_>, HashMap<String, Option<String>>, "Please use `cursor` protocol for hstore type"),
82+
(PostgresCSVSourceParser<'_>, Vector, "Please use `binary` protocol for vector type"),
83+
(PostgresCSVSourceParser<'_>, HalfVector, "Please use `binary` protocol for halfvector type"),
84+
(PostgresCSVSourceParser<'_>, Bit, "Please use `binary` protocol for bit type"),
85+
(PostgresCSVSourceParser<'_>, SparseVector, "Please use `binary` protocol for sparsevector type"),
86+
87+
88+
(PostgresSimpleSourceParser,HashMap<String, Option<String>>, "unimplemented"),
89+
(PostgresSimpleSourceParser,Value, "unimplemented"),
90+
(PostgresSimpleSourceParser, Vector, "Please use `binary` protocol for vector type"),
91+
(PostgresSimpleSourceParser, HalfVector, "Please use `binary` protocol for halfvector type"),
92+
(PostgresSimpleSourceParser, Bit, "Please use `binary` protocol for bit type"),
93+
(PostgresSimpleSourceParser, SparseVector, "Please use `binary` protocol for sparsevector type"),
94+
95+
);
96+
5597
// take a row and unwrap the interior field from column 0
5698
fn convert_row<'b, R: TryFrom<usize> + postgres::types::FromSql<'b> + Clone>(row: &'b Row) -> R {
5799
let nrows: Option<R> = row.get(0);
@@ -482,6 +524,10 @@ impl_produce!(
482524
Uuid,
483525
Value,
484526
IpInet,
527+
Vector,
528+
HalfVector,
529+
Bit,
530+
SparseVector,
485531
Vec<Option<bool>>,
486532
Vec<Option<i16>>,
487533
Vec<Option<i32>>,
@@ -773,22 +819,6 @@ macro_rules! impl_csv_vec_produce {
773819

774820
impl_csv_vec_produce!(i8, i16, i32, i64, f32, f64, Decimal, String,);
775821

776-
impl Produce<'_, HashMap<String, Option<String>>> for PostgresCSVSourceParser<'_> {
777-
type Error = PostgresSourceError;
778-
#[throws(PostgresSourceError)]
779-
fn produce(&mut self) -> HashMap<String, Option<String>> {
780-
unimplemented!("Please use `cursor` protocol for hstore type");
781-
}
782-
}
783-
784-
impl Produce<'_, Option<HashMap<String, Option<String>>>> for PostgresCSVSourceParser<'_> {
785-
type Error = PostgresSourceError;
786-
#[throws(PostgresSourceError)]
787-
fn produce(&mut self) -> Option<HashMap<String, Option<String>>> {
788-
unimplemented!("Please use `cursor` protocol for hstore type");
789-
}
790-
}
791-
792822
impl Produce<'_, bool> for PostgresCSVSourceParser<'_> {
793823
type Error = PostgresSourceError;
794824

@@ -1219,6 +1249,10 @@ impl_produce!(
12191249
Uuid,
12201250
Value,
12211251
IpInet,
1252+
Vector,
1253+
HalfVector,
1254+
Bit,
1255+
SparseVector,
12221256
HashMap<String, Option<String>>,
12231257
Vec<Option<bool>>,
12241258
Vec<Option<String>>,
@@ -1403,30 +1437,6 @@ impl PartitionParser<'_> for PostgresSimpleSourceParser {
14031437
}
14041438
}
14051439

1406-
macro_rules! impl_simple_produce_unimplemented {
1407-
($($t: ty,)+) => {
1408-
$(
1409-
impl<'r, 'a> Produce<'r, $t> for PostgresSimpleSourceParser {
1410-
type Error = PostgresSourceError;
1411-
1412-
#[throws(PostgresSourceError)]
1413-
fn produce(&'r mut self) -> $t {
1414-
unimplemented!("not implemented!");
1415-
}
1416-
}
1417-
1418-
impl<'r, 'a> Produce<'r, Option<$t>> for PostgresSimpleSourceParser {
1419-
type Error = PostgresSourceError;
1420-
1421-
#[throws(PostgresSourceError)]
1422-
fn produce(&'r mut self) -> Option<$t> {
1423-
unimplemented!("not implemented!");
1424-
}
1425-
}
1426-
)+
1427-
};
1428-
}
1429-
14301440
macro_rules! impl_simple_produce {
14311441
($($t: ty,)+) => {
14321442
$(
@@ -1591,10 +1601,6 @@ impl<'r> Produce<'r, Option<Decimal>> for PostgresSimpleSourceParser {
15911601
}
15921602
}
15931603

1594-
impl_simple_produce_unimplemented!(
1595-
Value,
1596-
HashMap<String, Option<String>>,);
1597-
15981604
impl<'r> Produce<'r, &'r str> for PostgresSimpleSourceParser {
15991605
type Error = PostgresSourceError;
16001606

connectorx/src/sources/postgres/typesystem.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ use serde_json::Value;
66
use std::collections::HashMap;
77
use uuid::Uuid;
88

9+
use pgvector::{Bit, HalfVector, SparseVector, Vector};
10+
911
#[derive(Copy, Clone, Debug)]
1012
pub enum PostgresTypeSystem {
1113
Bool(bool),
@@ -40,6 +42,10 @@ pub enum PostgresTypeSystem {
4042
HSTORE(bool),
4143
Name(bool),
4244
Inet(bool),
45+
Vector(bool),
46+
HalfVec(bool),
47+
Bit(bool),
48+
SparseVec(bool),
4349
}
4450

4551
impl_typesystem! {
@@ -71,6 +77,10 @@ impl_typesystem! {
7177
{ JSON | JSONB => Value }
7278
{ HSTORE => HashMap<String, Option<String>> }
7379
{ Inet => IpInet }
80+
{ Vector => Vector }
81+
{ HalfVec => HalfVector }
82+
{ Bit => Bit }
83+
{ SparseVec => SparseVector }
7484
}
7585
}
7686

@@ -108,6 +118,10 @@ impl<'a> From<&'a Type> for PostgresTypeSystem {
108118
"jsonb" => JSONB(true),
109119
"hstore" => HSTORE(true),
110120
"inet" => Inet(true),
121+
"vector" => Vector(true),
122+
"halfvec" => HalfVec(true),
123+
"bit" => Bit(true),
124+
"sparsevec" => SparseVec(true),
111125
_ => match ty.kind() {
112126
postgres::types::Kind::Enum(_) => Enum(true),
113127
_ => unimplemented!("{}", ty.name()),

connectorx/src/transports/postgres_arrow.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::typesystem::TypeConversion;
1414
use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
1515
use cidr_02::IpInet;
1616
use num_traits::ToPrimitive;
17+
use pgvector::{Bit, HalfVector, SparseVector, Vector};
1718
use postgres::NoTls;
1819
use postgres_openssl::MakeTlsConnector;
1920
use rust_decimal::Decimal;
@@ -76,6 +77,10 @@ macro_rules! impl_postgres_transport {
7677
{ Float4Array[Vec<Option<f32>>] => Float32Array[Vec<Option<f32>>] | conversion auto }
7778
{ Float8Array[Vec<Option<f64>>] => Float64Array[Vec<Option<f64>>] | conversion auto }
7879
{ NumericArray[Vec<Option<Decimal>>] => DecimalArray[Vec<Option<Decimal>>] | conversion auto }
80+
{ Vector[Vector] => Float32Array[Vec<Option<f32>>] | conversion none }
81+
{ HalfVec[HalfVector] => Float32Array[Vec<Option<f32>>] | conversion none }
82+
{ Bit[Bit] => LargeBinary[Vec<u8>] | conversion none }
83+
{ SparseVec[SparseVector] => Float32Array[Vec<Option<f32>>] | conversion none }
7984
}
8085
);
8186
}
@@ -140,3 +145,57 @@ impl<P, C> TypeConversion<Value, String> for PostgresArrowTransport<P, C> {
140145
val.to_string()
141146
}
142147
}
148+
149+
impl<P, C> TypeConversion<Vector, Vec<Option<f32>>> for PostgresArrowTransport<P, C> {
150+
fn convert(val: Vector) -> Vec<Option<f32>> {
151+
val.to_vec().into_iter().map(Some).collect()
152+
}
153+
}
154+
155+
impl<P, C> TypeConversion<Option<Vector>, Option<Vec<Option<f32>>>>
156+
for PostgresArrowTransport<P, C>
157+
{
158+
fn convert(val: Option<Vector>) -> Option<Vec<Option<f32>>> {
159+
val.map(|val| val.to_vec().into_iter().map(Some).collect())
160+
}
161+
}
162+
163+
impl<P, C> TypeConversion<HalfVector, Vec<Option<f32>>> for PostgresArrowTransport<P, C> {
164+
fn convert(val: HalfVector) -> Vec<Option<f32>> {
165+
val.to_vec().into_iter().map(|v| Some(v.to_f32())).collect()
166+
}
167+
}
168+
169+
impl<P, C> TypeConversion<Option<HalfVector>, Option<Vec<Option<f32>>>>
170+
for PostgresArrowTransport<P, C>
171+
{
172+
fn convert(val: Option<HalfVector>) -> Option<Vec<Option<f32>>> {
173+
val.map(|val| val.to_vec().into_iter().map(|v| Some(v.to_f32())).collect())
174+
}
175+
}
176+
177+
impl<P, C> TypeConversion<Bit, Vec<u8>> for PostgresArrowTransport<P, C> {
178+
fn convert(val: Bit) -> Vec<u8> {
179+
val.as_bytes().into()
180+
}
181+
}
182+
183+
impl<P, C> TypeConversion<Option<Bit>, Option<Vec<u8>>> for PostgresArrowTransport<P, C> {
184+
fn convert(val: Option<Bit>) -> Option<Vec<u8>> {
185+
val.map(|val| val.as_bytes().into())
186+
}
187+
}
188+
189+
impl<P, C> TypeConversion<SparseVector, Vec<Option<f32>>> for PostgresArrowTransport<P, C> {
190+
fn convert(val: SparseVector) -> Vec<Option<f32>> {
191+
val.to_vec().into_iter().map(Some).collect()
192+
}
193+
}
194+
195+
impl<P, C> TypeConversion<Option<SparseVector>, Option<Vec<Option<f32>>>>
196+
for PostgresArrowTransport<P, C>
197+
{
198+
fn convert(val: Option<SparseVector>) -> Option<Vec<Option<f32>>> {
199+
val.map(|val| val.to_vec().into_iter().map(Some).collect())
200+
}
201+
}

0 commit comments

Comments
 (0)