Skip to content

Commit 8ebe859

Browse files
committed
Start on binary copy rewrite
1 parent cff1189 commit 8ebe859

File tree

4 files changed

+242
-0
lines changed

4 files changed

+242
-0
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ members = [
99
"postgres-protocol",
1010
"postgres-types",
1111
"tokio-postgres",
12+
"tokio-postgres-binary-copy",
1213
]
1314

1415
[profile.release]

tokio-postgres-binary-copy/Cargo.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[package]
2+
name = "tokio-postgres-binary-copy"
3+
version = "0.1.0"
4+
authors = ["Steven Fackler <[email protected]>"]
5+
edition = "2018"
6+
7+
[dependencies]
8+
bytes = "0.4"
9+
futures-preview = "=0.3.0-alpha.19"
10+
parking_lot = "0.9"
11+
pin-project-lite = "0.1"
12+
tokio-postgres = { version = "=0.5.0-alpha.1", default-features = false, path = "../tokio-postgres" }
13+
14+
[dev-dependencies]
15+
tokio = "=0.2.0-alpha.6"
16+
tokio-postgres = { version = "=0.5.0-alpha.1", path = "../tokio-postgres" }
17+

tokio-postgres-binary-copy/src/lib.rs

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
use bytes::{BigEndian, BufMut, ByteOrder, Bytes, BytesMut};
2+
use futures::{future, Stream};
3+
use parking_lot::Mutex;
4+
use pin_project_lite::pin_project;
5+
use std::convert::TryFrom;
6+
use std::error::Error;
7+
use std::future::Future;
8+
use std::pin::Pin;
9+
use std::sync::Arc;
10+
use std::task::{Context, Poll};
11+
use tokio_postgres::types::{IsNull, ToSql, Type};
12+
13+
#[cfg(test)]
14+
mod test;
15+
16+
const BLOCK_SIZE: usize = 4096;
17+
18+
pin_project! {
19+
pub struct BinaryCopyStream<F> {
20+
#[pin]
21+
future: F,
22+
buf: Arc<Mutex<BytesMut>>,
23+
done: bool,
24+
}
25+
}
26+
27+
impl<F> BinaryCopyStream<F>
28+
where
29+
F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
30+
{
31+
pub fn new<M>(types: &[Type], write_values: M) -> BinaryCopyStream<F>
32+
where
33+
M: FnOnce(BinaryCopyWriter) -> F,
34+
{
35+
let mut buf = BytesMut::new();
36+
buf.reserve(11 + 4 + 4);
37+
buf.put_slice(b"PGCOPY\n\xff\r\n\0"); // magic
38+
buf.put_i32_be(0); // flags
39+
buf.put_i32_be(0); // header extension
40+
41+
let buf = Arc::new(Mutex::new(buf));
42+
let writer = BinaryCopyWriter {
43+
buf: buf.clone(),
44+
types: types.to_vec(),
45+
idx: 0,
46+
};
47+
48+
BinaryCopyStream {
49+
future: write_values(writer),
50+
buf,
51+
done: false,
52+
}
53+
}
54+
}
55+
56+
impl<F> Stream for BinaryCopyStream<F>
57+
where
58+
F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
59+
{
60+
type Item = Result<Bytes, Box<dyn Error + Sync + Send>>;
61+
62+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63+
let this = self.project();
64+
65+
if *this.done {
66+
return Poll::Ready(None);
67+
}
68+
69+
*this.done = this.future.poll(cx)?.is_ready();
70+
71+
let mut buf = this.buf.lock();
72+
if *this.done {
73+
buf.reserve(2);
74+
buf.put_i16_be(-1);
75+
Poll::Ready(Some(Ok(buf.take().freeze())))
76+
} else if buf.len() > BLOCK_SIZE {
77+
Poll::Ready(Some(Ok(buf.take().freeze())))
78+
} else {
79+
Poll::Pending
80+
}
81+
}
82+
}
83+
84+
// FIXME this should really just take a reference to the buffer, but that requires HKT :(
85+
pub struct BinaryCopyWriter {
86+
buf: Arc<Mutex<BytesMut>>,
87+
types: Vec<Type>,
88+
idx: usize,
89+
}
90+
91+
impl BinaryCopyWriter {
92+
pub async fn write(
93+
&mut self,
94+
value: &(dyn ToSql + Send),
95+
) -> Result<(), Box<dyn Error + Sync + Send>> {
96+
future::poll_fn(|_| {
97+
if self.buf.lock().len() > BLOCK_SIZE {
98+
Poll::Pending
99+
} else {
100+
Poll::Ready(())
101+
}
102+
})
103+
.await;
104+
105+
let mut buf = self.buf.lock();
106+
if self.idx == 0 {
107+
buf.reserve(2);
108+
buf.put_i16_be(self.types.len() as i16);
109+
}
110+
let idx = buf.len();
111+
buf.reserve(4);
112+
buf.put_i32_be(0);
113+
let len = match value.to_sql_checked(&self.types[self.idx], &mut buf)? {
114+
IsNull::Yes => -1,
115+
IsNull::No => i32::try_from(buf.len() - idx - 4)?,
116+
};
117+
BigEndian::write_i32(&mut buf[idx..], len);
118+
119+
self.idx = (self.idx + 1) % self.types.len();
120+
121+
Ok(())
122+
}
123+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
use crate::BinaryCopyStream;
2+
use tokio_postgres::types::Type;
3+
use tokio_postgres::{Client, NoTls};
4+
5+
async fn connect() -> Client {
6+
let (client, connection) =
7+
tokio_postgres::connect("host=localhost port=5433 user=postgres", NoTls)
8+
.await
9+
.unwrap();
10+
tokio::spawn(async {
11+
connection.await.unwrap();
12+
});
13+
client
14+
}
15+
16+
#[tokio::test]
17+
async fn write_basic() {
18+
let client = connect().await;
19+
20+
client
21+
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
22+
.await
23+
.unwrap();
24+
25+
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| {
26+
async move {
27+
w.write(&1i32).await?;
28+
w.write(&"foobar").await?;
29+
30+
w.write(&2i32).await?;
31+
w.write(&None::<&str>).await?;
32+
33+
Ok(())
34+
}
35+
});
36+
37+
client
38+
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
39+
.await
40+
.unwrap();
41+
42+
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
43+
assert_eq!(rows.len(), 2);
44+
assert_eq!(rows[0].get::<_, i32>(0), 1);
45+
assert_eq!(rows[0].get::<_, Option<&str>>(1), Some("foobar"));
46+
assert_eq!(rows[1].get::<_, i32>(0), 2);
47+
assert_eq!(rows[1].get::<_, Option<&str>>(1), None);
48+
}
49+
50+
#[tokio::test]
51+
async fn write_many_rows() {
52+
let client = connect().await;
53+
54+
client
55+
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
56+
.await
57+
.unwrap();
58+
59+
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| async move {
60+
for i in 0..10_000i32 {
61+
w.write(&i).await?;
62+
w.write(&format!("the value for {}", i)).await?;
63+
}
64+
65+
Ok(())
66+
});
67+
68+
client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream).await.unwrap();
69+
70+
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
71+
for (i, row) in rows.iter().enumerate() {
72+
assert_eq!(row.get::<_, i32>(0), i as i32);
73+
assert_eq!(row.get::<_, &str>(1), format!("the value for {}", i));
74+
}
75+
}
76+
77+
#[tokio::test]
78+
async fn write_big_rows() {
79+
let client = connect().await;
80+
81+
client.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)").await.unwrap();
82+
83+
let stream = BinaryCopyStream::new(&[Type::INT4, Type::BYTEA], |mut w| {
84+
async move {
85+
for i in 0..2i32 {
86+
w.write(&i).await.unwrap();
87+
w.write(&vec![i as u8; 128 * 1024]).await.unwrap();
88+
}
89+
90+
Ok(())
91+
}
92+
});
93+
94+
client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream).await.unwrap();
95+
96+
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
97+
for (i, row) in rows.iter().enumerate() {
98+
assert_eq!(row.get::<_, i32>(0), i as i32);
99+
assert_eq!(row.get::<_, &[u8]>(1), &*vec![i as u8; 128 * 1024]);
100+
}
101+
}

0 commit comments

Comments
 (0)