Skip to content

Commit 23c8966

Browse files
committed
feat: implement records to struct conversion with FromSql derive
1 parent e1cd6be commit 23c8966

File tree

5 files changed

+261
-40
lines changed

5 files changed

+261
-40
lines changed

postgres-derive-test/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::fmt;
77
mod composites;
88
mod domains;
99
mod enums;
10+
mod records;
1011
mod transparent;
1112

1213
pub fn test_type<T, S>(conn: &mut Client, sql_type: &str, checks: &[(T, S)])

postgres-derive-test/src/records.rs

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
use postgres::{Client, NoTls};
2+
use postgres_types::{FromSql, ToSql, WrongType};
3+
use std::error::Error;
4+
5+
#[test]
6+
fn basic() {
7+
#[derive(FromSql, ToSql, Debug, PartialEq)]
8+
struct InventoryItem {
9+
name: String,
10+
supplier_id: i32,
11+
price: Option<f64>,
12+
}
13+
14+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
15+
16+
let expected = InventoryItem {
17+
name: "foobar".to_owned(),
18+
supplier_id: 100,
19+
price: Some(15.50),
20+
};
21+
22+
let got = conn
23+
.query_one("SELECT ('foobar', 100, 15.50::double precision)", &[])
24+
.unwrap()
25+
.try_get::<_, InventoryItem>(0)
26+
.unwrap();
27+
28+
assert_eq!(got, expected);
29+
}
30+
31+
#[test]
32+
fn field_count_mismatch() {
33+
#[derive(FromSql, Debug, PartialEq)]
34+
struct InventoryItem {
35+
name: String,
36+
supplier_id: i32,
37+
price: Option<f64>,
38+
}
39+
40+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
41+
42+
let err = conn
43+
.query_one("SELECT ('foobar', 100)", &[])
44+
.unwrap()
45+
.try_get::<_, InventoryItem>(0)
46+
.unwrap_err();
47+
err.source().unwrap().is::<WrongType>();
48+
49+
let err = conn
50+
.query_one("SELECT ('foobar', 100, 15.50, 'extra')", &[])
51+
.unwrap()
52+
.try_get::<_, InventoryItem>(0)
53+
.unwrap_err();
54+
err.source().unwrap().is::<WrongType>();
55+
}
56+
57+
#[test]
58+
fn wrong_type() {
59+
#[derive(FromSql, Debug, PartialEq)]
60+
struct InventoryItem {
61+
name: String,
62+
supplier_id: i32,
63+
price: Option<f64>,
64+
}
65+
66+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
67+
68+
let err = conn
69+
.query_one("SELECT ('foobar', 'not_an_int', 15.50)", &[])
70+
.unwrap()
71+
.try_get::<_, InventoryItem>(0)
72+
.unwrap_err();
73+
assert!(err.source().unwrap().is::<WrongType>());
74+
75+
let err = conn
76+
.query_one("SELECT (123, 100, 15.50)", &[])
77+
.unwrap()
78+
.try_get::<_, InventoryItem>(0)
79+
.unwrap_err();
80+
assert!(err.source().unwrap().is::<WrongType>());
81+
}
82+
83+
#[test]
84+
fn nested_structs() {
85+
#[derive(FromSql, Debug, PartialEq)]
86+
struct Address {
87+
street: String,
88+
city: Option<String>,
89+
}
90+
91+
#[derive(FromSql, Debug, PartialEq)]
92+
struct Person {
93+
name: String,
94+
age: Option<i32>,
95+
address: Address,
96+
}
97+
98+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
99+
100+
let result: Person = conn
101+
.query_one(
102+
"SELECT ('John', 30, ROW('123 Main St', 'Springfield'))",
103+
&[],
104+
)
105+
.unwrap()
106+
.get(0);
107+
108+
let expected = Person {
109+
name: "John".to_owned(),
110+
age: Some(30),
111+
address: Address {
112+
street: "123 Main St".to_owned(),
113+
city: Some("Springfield".to_owned()),
114+
},
115+
};
116+
117+
assert_eq!(result, expected);
118+
}
119+
120+
#[test]
121+
fn generics() {
122+
#[derive(FromSql, ToSql, Debug, PartialEq)]
123+
struct GenericItem<T, U> {
124+
first: T,
125+
second: U,
126+
}
127+
128+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
129+
130+
let expected = GenericItem {
131+
first: "test".to_owned(),
132+
second: 42,
133+
};
134+
135+
let got = conn
136+
.query_one("SELECT ('test', 42)", &[])
137+
.unwrap()
138+
.try_get::<_, GenericItem<String, i32>>(0)
139+
.unwrap();
140+
141+
assert_eq!(got, expected);
142+
}

postgres-derive/src/accepts.rs

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,47 @@ pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> Toke
6666
}
6767
}
6868

69-
pub fn composite_body(name: &str, trait_: &str, fields: &[Field]) -> TokenStream {
69+
pub fn composite_body_from_sql(name: &str, fields: &[Field]) -> TokenStream {
7070
let num_fields = fields.len();
71-
let trait_ = Ident::new(trait_, Span::call_site());
71+
let trait_ = Ident::new("FromSql", Span::call_site());
7272
let traits = iter::repeat(&trait_);
7373
let field_names = fields.iter().map(|f| &f.name);
7474
let field_types = fields.iter().map(|f| &f.type_);
7575

7676
quote! {
77-
if type_.name() != #name {
78-
return false;
77+
match *type_.kind() {
78+
::postgres_types::Kind::Composite(ref fields) if type_.name() == #name => {
79+
if fields.len() != #num_fields {
80+
return false;
81+
}
82+
83+
fields.iter().all(|f| {
84+
match f.name() {
85+
#(
86+
#field_names => {
87+
<#field_types as ::postgres_types::#traits>::accepts(f.type_())
88+
}
89+
)*
90+
_ => false,
91+
}
92+
})
93+
},
94+
::postgres_types::Kind::Pseudo => true,
95+
_ => false,
7996
}
97+
}
98+
}
8099

100+
pub fn composite_body_to_sql(name: &str, fields: &[Field]) -> TokenStream {
101+
let num_fields = fields.len();
102+
let trait_ = Ident::new("ToSql", Span::call_site());
103+
let traits = iter::repeat(&trait_);
104+
let field_names = fields.iter().map(|f| &f.name);
105+
let field_types = fields.iter().map(|f| &f.type_);
106+
107+
quote! {
81108
match *type_.kind() {
82-
::postgres_types::Kind::Composite(ref fields) => {
109+
::postgres_types::Kind::Composite(ref fields) if type_.name() == #name => {
83110
if fields.len() != #num_fields {
84111
return false;
85112
}
@@ -94,7 +121,7 @@ pub fn composite_body(name: &str, trait_: &str, fields: &[Field]) -> TokenStream
94121
_ => false,
95122
}
96123
})
97-
}
124+
},
98125
_ => false,
99126
}
100127
}

postgres-derive/src/fromsql.rs

Lines changed: 83 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
101101
.map(|field| Field::parse(field, overrides.rename_all))
102102
.collect::<Result<Vec<_>, _>>()?;
103103
(
104-
accepts::composite_body(&name, "FromSql", &fields),
104+
accepts::composite_body_from_sql(&name, &fields),
105105
composite_body(&input.ident, &fields),
106106
)
107107
}
@@ -191,45 +191,96 @@ fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
191191
let field_names = &fields.iter().map(|f| &f.name).collect::<Vec<_>>();
192192
let field_idents = &fields.iter().map(|f| &f.ident).collect::<Vec<_>>();
193193

194+
let field_types = &fields.iter().map(|f| &f.type_).collect::<Vec<_>>();
195+
let field_indices = (0..fields.len()).collect::<Vec<_>>();
196+
let field_count = fields.len();
197+
194198
quote! {
195-
let fields = match *_type.kind() {
196-
postgres_types::Kind::Composite(ref fields) => fields,
197-
_ => unreachable!(),
198-
};
199+
match *_type.kind() {
200+
postgres_types::Kind::Composite(ref fields) => {
201+
let mut buf = buf;
202+
let num_fields = postgres_types::private::read_be_i32(&mut buf)?;
203+
if num_fields as usize != fields.len() {
204+
return std::result::Result::Err(std::convert::Into::into(format!(
205+
"invalid field count: {} vs {}",
206+
num_fields,
207+
fields.len()
208+
)));
209+
}
199210

200-
let mut buf = buf;
201-
let num_fields = postgres_types::private::read_be_i32(&mut buf)?;
202-
if num_fields as usize != fields.len() {
203-
return std::result::Result::Err(
204-
std::convert::Into::into(format!("invalid field count: {} vs {}", num_fields, fields.len())));
205-
}
211+
#(
212+
let mut #temp_vars = std::option::Option::None;
213+
)*
206214

207-
#(
208-
let mut #temp_vars = std::option::Option::None;
209-
)*
215+
for field in fields {
216+
let oid = postgres_types::private::read_be_i32(&mut buf)? as u32;
217+
if oid != field.type_().oid() {
218+
return std::result::Result::Err(std::convert::Into::into("unexpected OID"));
219+
}
210220

211-
for field in fields {
212-
let oid = postgres_types::private::read_be_i32(&mut buf)? as u32;
213-
if oid != field.type_().oid() {
214-
return std::result::Result::Err(std::convert::Into::into("unexpected OID"));
215-
}
221+
match field.name() {
222+
#(
223+
#field_names => {
224+
#temp_vars = std::option::Option::Some(
225+
postgres_types::private::read_value(field.type_(), &mut buf)?,
226+
);
227+
}
228+
)*
229+
_ => unreachable!(),
230+
}
231+
}
232+
233+
std::result::Result::Ok(#ident {
234+
#(
235+
#field_idents: #temp_vars.unwrap(),
236+
)*
237+
})
238+
},
239+
postgres_types::Kind::Pseudo if *_type == postgres_types::Type::RECORD => {
240+
let mut buf = buf;
241+
let num_fields = postgres_types::private::read_be_i32(&mut buf)?;
242+
if num_fields as usize != #field_count {
243+
return std::result::Result::Err(
244+
std::convert::Into::into(format!("invalid field count: {} vs {}", num_fields, #field_count)));
245+
}
216246

217-
match field.name() {
218247
#(
219-
#field_names => {
220-
#temp_vars = std::option::Option::Some(
221-
postgres_types::private::read_value(field.type_(), &mut buf)?);
222-
}
248+
let mut #temp_vars = std::option::Option::None;
223249
)*
224-
_ => unreachable!(),
225-
}
226-
}
227250

228-
std::result::Result::Ok(#ident {
229-
#(
230-
#field_idents: #temp_vars.unwrap(),
231-
)*
232-
})
251+
for i in 0..num_fields {
252+
let oid = postgres_types::private::read_be_i32(&mut buf)? as u32;
253+
let ty = match postgres_types::Type::from_oid(oid) {
254+
std::option::Option::None => {
255+
return std::result::Result::Err(std::convert::Into::into(
256+
format!("cannot decode OID {} inside of anonymous record", oid)));
257+
}
258+
std::option::Option::Some(ty) => ty,
259+
};
260+
261+
match i as usize {
262+
#(
263+
#field_indices => {
264+
if !<#field_types as postgres_types::FromSql>::accepts(&ty) {
265+
return std::result::Result::Err(std::boxed::Box::new(
266+
postgres_types::WrongType::new::<#field_types>(ty.clone())));
267+
}
268+
#temp_vars = std::option::Option::Some(
269+
postgres_types::private::read_value(&ty, &mut buf)?);
270+
}
271+
)*
272+
_ => unreachable!(),
273+
}
274+
}
275+
276+
std::result::Result::Ok(#ident {
277+
#(
278+
#field_idents: #temp_vars.unwrap(),
279+
)*
280+
})
281+
},
282+
_ => unreachable!(),
283+
}
233284
}
234285
}
235286

postgres-derive/src/tosql.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
9595
.map(|field| Field::parse(field, overrides.rename_all))
9696
.collect::<Result<Vec<_>, _>>()?;
9797
(
98-
accepts::composite_body(&name, "ToSql", &fields),
98+
accepts::composite_body_to_sql(&name, &fields),
9999
composite_body(&fields),
100100
)
101101
}
@@ -112,7 +112,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
112112
let generics = append_generic_bound(input.generics.to_owned(), &new_tosql_bound());
113113
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
114114
let out = quote! {
115-
impl#impl_generics postgres_types::ToSql for #ident#ty_generics #where_clause {
115+
impl #impl_generics postgres_types::ToSql for #ident #ty_generics #where_clause {
116116
fn to_sql(&self,
117117
_type: &postgres_types::Type,
118118
buf: &mut postgres_types::private::BytesMut)

0 commit comments

Comments
 (0)