Skip to content

Commit a254e6e

Browse files
committed
Blocking binary copy support
1 parent cc8d8fe commit a254e6e

File tree

9 files changed

+201
-33
lines changed

9 files changed

+201
-33
lines changed

postgres/src/binary_copy.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
//! Utilities for working with the PostgreSQL binary copy format.
2+
3+
use crate::types::{ToSql, Type};
4+
use crate::{CopyInWriter, CopyOutReader, Error, Rt};
5+
use fallible_iterator::FallibleIterator;
6+
use futures::StreamExt;
7+
use std::pin::Pin;
8+
#[doc(inline)]
9+
pub use tokio_postgres::binary_copy::BinaryCopyOutRow;
10+
use tokio_postgres::binary_copy::{self, BinaryCopyOutStream};
11+
12+
/// A type which serializes rows into the PostgreSQL binary copy format.
13+
///
14+
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
15+
pub struct BinaryCopyInWriter<'a> {
16+
runtime: Rt<'a>,
17+
sink: Pin<Box<binary_copy::BinaryCopyInWriter>>,
18+
}
19+
20+
impl<'a> BinaryCopyInWriter<'a> {
21+
/// Creates a new writer which will write rows of the provided types.
22+
pub fn new(writer: CopyInWriter<'a>, types: &[Type]) -> BinaryCopyInWriter<'a> {
23+
let stream = writer
24+
.sink
25+
.into_unpinned()
26+
.expect("writer has already been written to");
27+
28+
BinaryCopyInWriter {
29+
runtime: writer.runtime,
30+
sink: Box::pin(binary_copy::BinaryCopyInWriter::new(stream, types)),
31+
}
32+
}
33+
34+
/// Writes a single row.
35+
///
36+
/// # Panics
37+
///
38+
/// Panics if the number of values provided does not match the number expected.
39+
pub fn write(&mut self, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> {
40+
self.runtime.block_on(self.sink.as_mut().write(values))
41+
}
42+
43+
/// A maximally-flexible version of `write`.
44+
///
45+
/// # Panics
46+
///
47+
/// Panics if the number of values provided does not match the number expected.
48+
pub fn write_raw<'b, I>(&mut self, values: I) -> Result<(), Error>
49+
where
50+
I: IntoIterator<Item = &'b dyn ToSql>,
51+
I::IntoIter: ExactSizeIterator,
52+
{
53+
self.runtime.block_on(self.sink.as_mut().write_raw(values))
54+
}
55+
56+
/// Completes the copy, returning the number of rows added.
57+
///
58+
/// This method *must* be used to complete the copy process. If it is not, the copy will be aborted.
59+
pub fn finish(mut self) -> Result<u64, Error> {
60+
self.runtime.block_on(self.sink.as_mut().finish())
61+
}
62+
}
63+
64+
/// An iterator of rows deserialized from the PostgreSQL binary copy format.
65+
pub struct BinaryCopyOutIter<'a> {
66+
runtime: Rt<'a>,
67+
stream: Pin<Box<BinaryCopyOutStream>>,
68+
}
69+
70+
impl<'a> BinaryCopyOutIter<'a> {
71+
/// Creates a new iterator from a raw copy out reader and the types of the columns being returned.
72+
pub fn new(reader: CopyOutReader<'a>, types: &[Type]) -> BinaryCopyOutIter<'a> {
73+
let stream = reader
74+
.stream
75+
.into_unpinned()
76+
.expect("reader has already been read from");
77+
78+
BinaryCopyOutIter {
79+
runtime: reader.runtime,
80+
stream: Box::pin(BinaryCopyOutStream::new(stream, types)),
81+
}
82+
}
83+
}
84+
85+
impl FallibleIterator for BinaryCopyOutIter<'_> {
86+
type Item = BinaryCopyOutRow;
87+
type Error = Error;
88+
89+
fn next(&mut self) -> Result<Option<BinaryCopyOutRow>, Error> {
90+
self.runtime.block_on(self.stream.next()).transpose()
91+
}
92+
}

postgres/src/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ impl Client {
382382
T: ?Sized + ToStatement,
383383
{
384384
let stream = self.runtime.block_on(self.client.copy_out(query))?;
385-
CopyOutReader::new(self.rt(), stream)
385+
Ok(CopyOutReader::new(self.rt(), stream))
386386
}
387387

388388
/// Executes a sequence of SQL statements using the simple query protocol.

postgres/src/copy_in_writer.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
1+
use crate::lazy_pin::LazyPin;
12
use crate::Rt;
23
use bytes::{Bytes, BytesMut};
34
use futures::SinkExt;
45
use std::io;
56
use std::io::Write;
6-
use std::pin::Pin;
77
use tokio_postgres::{CopyInSink, Error};
88

99
/// The writer returned by the `copy_in` method.
1010
///
1111
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
1212
pub struct CopyInWriter<'a> {
13-
runtime: Rt<'a>,
14-
sink: Pin<Box<CopyInSink<Bytes>>>,
13+
pub(crate) runtime: Rt<'a>,
14+
pub(crate) sink: LazyPin<CopyInSink<Bytes>>,
1515
buf: BytesMut,
1616
}
1717

1818
impl<'a> CopyInWriter<'a> {
1919
pub(crate) fn new(runtime: Rt<'a>, sink: CopyInSink<Bytes>) -> CopyInWriter<'a> {
2020
CopyInWriter {
2121
runtime,
22-
sink: Box::pin(sink),
22+
sink: LazyPin::new(sink),
2323
buf: BytesMut::new(),
2424
}
2525
}
@@ -29,7 +29,7 @@ impl<'a> CopyInWriter<'a> {
2929
/// If this is not called, the copy will be aborted.
3030
pub fn finish(mut self) -> Result<u64, Error> {
3131
self.flush_inner()?;
32-
self.runtime.block_on(self.sink.as_mut().finish())
32+
self.runtime.block_on(self.sink.pinned().finish())
3333
}
3434

3535
fn flush_inner(&mut self) -> Result<(), Error> {
@@ -38,7 +38,7 @@ impl<'a> CopyInWriter<'a> {
3838
}
3939

4040
self.runtime
41-
.block_on(self.sink.as_mut().send(self.buf.split().freeze()))
41+
.block_on(self.sink.pinned().send(self.buf.split().freeze()))
4242
}
4343
}
4444

postgres/src/copy_out_reader.rs

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,24 @@
1+
use crate::lazy_pin::LazyPin;
12
use crate::Rt;
23
use bytes::{Buf, Bytes};
34
use futures::StreamExt;
4-
use std::io::{self, BufRead, Cursor, Read};
5-
use std::pin::Pin;
6-
use tokio_postgres::{CopyOutStream, Error};
5+
use std::io::{self, BufRead, Read};
6+
use tokio_postgres::CopyOutStream;
77

88
/// The reader returned by the `copy_out` method.
99
pub struct CopyOutReader<'a> {
10-
runtime: Rt<'a>,
11-
stream: Pin<Box<CopyOutStream>>,
12-
cur: Cursor<Bytes>,
10+
pub(crate) runtime: Rt<'a>,
11+
pub(crate) stream: LazyPin<CopyOutStream>,
12+
cur: Bytes,
1313
}
1414

1515
impl<'a> CopyOutReader<'a> {
16-
pub(crate) fn new(
17-
mut runtime: Rt<'a>,
18-
stream: CopyOutStream,
19-
) -> Result<CopyOutReader<'a>, Error> {
20-
let mut stream = Box::pin(stream);
21-
let cur = match runtime.block_on(stream.next()) {
22-
Some(Ok(cur)) => cur,
23-
Some(Err(e)) => return Err(e),
24-
None => Bytes::new(),
25-
};
26-
27-
Ok(CopyOutReader {
16+
pub(crate) fn new(runtime: Rt<'a>, stream: CopyOutStream) -> CopyOutReader<'a> {
17+
CopyOutReader {
2818
runtime,
29-
stream,
30-
cur: Cursor::new(cur),
31-
})
19+
stream: LazyPin::new(stream),
20+
cur: Bytes::new(),
21+
}
3222
}
3323
}
3424

@@ -44,15 +34,15 @@ impl Read for CopyOutReader<'_> {
4434

4535
impl BufRead for CopyOutReader<'_> {
4636
fn fill_buf(&mut self) -> io::Result<&[u8]> {
47-
if self.cur.remaining() == 0 {
48-
match self.runtime.block_on(self.stream.next()) {
49-
Some(Ok(cur)) => self.cur = Cursor::new(cur),
37+
if !self.cur.has_remaining() {
38+
match self.runtime.block_on(self.stream.pinned().next()) {
39+
Some(Ok(cur)) => self.cur = cur,
5040
Some(Err(e)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
5141
None => {}
5242
};
5343
}
5444

55-
Ok(Buf::bytes(&self.cur))
45+
Ok(self.cur.bytes())
5646
}
5747

5848
fn consume(&mut self, amt: usize) {

postgres/src/lazy_pin.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use std::pin::Pin;
2+
3+
pub(crate) struct LazyPin<T> {
4+
value: Box<T>,
5+
pinned: bool,
6+
}
7+
8+
impl<T> LazyPin<T> {
9+
pub fn new(value: T) -> LazyPin<T> {
10+
LazyPin {
11+
value: Box::new(value),
12+
pinned: false,
13+
}
14+
}
15+
16+
pub fn pinned(&mut self) -> Pin<&mut T> {
17+
self.pinned = true;
18+
unsafe { Pin::new_unchecked(&mut *self.value) }
19+
}
20+
21+
pub fn into_unpinned(self) -> Option<T> {
22+
if self.pinned {
23+
None
24+
} else {
25+
Some(*self.value)
26+
}
27+
}
28+
}

postgres/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ mod client;
7676
pub mod config;
7777
mod copy_in_writer;
7878
mod copy_out_reader;
79+
mod lazy_pin;
80+
pub mod binary_copy;
7981
mod row_iter;
8082
mod transaction;
8183

postgres/src/test.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use tokio_postgres::types::Type;
33
use tokio_postgres::NoTls;
44

55
use super::*;
6+
use crate::binary_copy::{BinaryCopyInWriter, BinaryCopyOutIter};
7+
use fallible_iterator::FallibleIterator;
68

79
#[test]
810
fn prepare() {
@@ -188,6 +190,31 @@ fn copy_in_abort() {
188190
assert_eq!(rows.len(), 0);
189191
}
190192

193+
#[test]
194+
fn binary_copy_in() {
195+
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
196+
197+
client
198+
.simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)")
199+
.unwrap();
200+
201+
let writer = client.copy_in("COPY foo FROM stdin BINARY").unwrap();
202+
let mut writer = BinaryCopyInWriter::new(writer, &[Type::INT4, Type::TEXT]);
203+
writer.write(&[&1i32, &"steven"]).unwrap();
204+
writer.write(&[&2i32, &"timothy"]).unwrap();
205+
writer.finish().unwrap();
206+
207+
let rows = client
208+
.query("SELECT id, name FROM foo ORDER BY id", &[])
209+
.unwrap();
210+
211+
assert_eq!(rows.len(), 2);
212+
assert_eq!(rows[0].get::<_, i32>(0), 1);
213+
assert_eq!(rows[0].get::<_, &str>(1), "steven");
214+
assert_eq!(rows[1].get::<_, i32>(0), 2);
215+
assert_eq!(rows[1].get::<_, &str>(1), "timothy");
216+
}
217+
191218
#[test]
192219
fn copy_out() {
193220
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
@@ -209,6 +236,32 @@ fn copy_out() {
209236
client.simple_query("SELECT 1").unwrap();
210237
}
211238

239+
#[test]
240+
fn binary_copy_out() {
241+
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
242+
243+
client
244+
.simple_query(
245+
"CREATE TEMPORARY TABLE foo (id INT, name TEXT);
246+
INSERT INTO foo (id, name) VALUES (1, 'steven'), (2, 'timothy');",
247+
)
248+
.unwrap();
249+
250+
let reader = client
251+
.copy_out("COPY foo (id, name) TO STDOUT BINARY")
252+
.unwrap();
253+
let rows = BinaryCopyOutIter::new(reader, &[Type::INT4, Type::TEXT])
254+
.collect::<Vec<_>>()
255+
.unwrap();
256+
assert_eq!(rows.len(), 2);
257+
assert_eq!(rows[0].get::<i32>(0), 1);
258+
assert_eq!(rows[0].get::<&str>(1), "steven");
259+
assert_eq!(rows[1].get::<i32>(0), 2);
260+
assert_eq!(rows[1].get::<&str>(1), "timothy");
261+
262+
client.simple_query("SELECT 1").unwrap();
263+
}
264+
212265
#[test]
213266
fn portal() {
214267
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();

postgres/src/transaction.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ impl<'a> Transaction<'a> {
155155
T: ?Sized + ToStatement,
156156
{
157157
let stream = self.runtime.block_on(self.transaction.copy_out(query))?;
158-
CopyOutReader::new(self.rt(), stream)
158+
Ok(CopyOutReader::new(self.rt(), stream))
159159
}
160160

161161
/// Like `Client::simple_query`.

tokio-postgres/src/copy_out.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ pub async fn copy_out(client: &InnerClient, statement: Statement) -> Result<Copy
2222
async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
2323
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
2424

25+
println!("a");
2526
match responses.next().await? {
2627
Message::BindComplete => {}
2728
_ => return Err(Error::unexpected_message()),
2829
}
2930

31+
println!("b");
3032
match responses.next().await? {
3133
Message::CopyOutResponse(_) => {}
3234
_ => return Err(Error::unexpected_message()),
@@ -50,6 +52,7 @@ impl Stream for CopyOutStream {
5052
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
5153
let this = self.project();
5254

55+
println!("c");
5356
match ready!(this.responses.poll_next(cx)?) {
5457
Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))),
5558
Message::CopyDone => Poll::Ready(None),

0 commit comments

Comments
 (0)