Skip to content

Commit 03150f4

Browse files
committed
Overhaul the copy_out API
Returning a Reader ends up with a really weird user experience where you have to make sure to drop it before making any other calls and it has to internally fast forward to the end of the data even if the user drops it early. Simply taking a Writer that all data is pushed into is sigificantly more straightforward.
1 parent 5fe76e2 commit 03150f4

File tree

2 files changed

+74
-134
lines changed

2 files changed

+74
-134
lines changed

src/stmt.rs

Lines changed: 70 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use debug_builders::DebugStruct;
44
use std::cell::{Cell, RefMut};
55
use std::collections::VecDeque;
66
use std::fmt;
7-
use std::io::{self, Cursor, BufRead, Read};
7+
use std::io::{self, Read, Write};
88

99
use error::{Error, DbError};
1010
use types::{SessionInfo, Type, ToSql, IsNull};
@@ -379,37 +379,29 @@ impl<'conn> Statement<'conn> {
379379
Ok(num)
380380
}
381381

382-
/// Executes a `COPY TO STDOUT` statement, returning a `Read`er of the
383-
/// resulting data.
382+
/// Executes a `COPY TO STDOUT` statement, passing the resulting data to
383+
/// the provided writer and returning the number of rows received.
384384
///
385385
/// See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html)
386386
/// for details on the data format.
387387
///
388388
/// If the statement is not a `COPY TO STDOUT` statement it will still be
389389
/// executed and this method will return an error.
390390
///
391-
/// # Warning
392-
///
393-
/// The underlying connection may not be used while the returned `Read`er
394-
/// exists. Any attempt to do so will panic.
395-
///
396391
/// # Examples
397392
///
398393
/// ```rust,no_run
399-
/// # use std::io::Read;
400394
/// # use postgres::{Connection, SslMode};
401395
/// # let conn = Connection::connect("", &SslMode::None).unwrap();
402396
/// conn.batch_execute("
403397
/// CREATE TABLE people (id INT PRIMARY KEY, name VARCHAR);
404398
/// INSERT INTO people (id, name) VALUES (1, 'john'), (2, 'jane');").unwrap();
405399
/// let stmt = conn.prepare("COPY people TO STDOUT").unwrap();
406-
/// let mut r = stmt.copy_out(&[]).unwrap();
407400
/// let mut buf = vec![];
408-
/// r.read_to_end(&mut buf).unwrap();
409-
/// r.finish().unwrap();
401+
/// let mut r = stmt.copy_out(&[], &mut buf).unwrap();
410402
/// assert_eq!(buf, b"1\tjohn\n2\tjane\n");
411403
/// ```
412-
pub fn copy_out<'a>(&'a self, params: &[&ToSql]) -> Result<CopyOutReader<'a>> {
404+
pub fn copy_out<'a, W: WriteWithInfo>(&'a self, params: &[&ToSql], w: &mut W) -> Result<u64> {
413405
try!(self.inner_execute("", 0, params));
414406
let mut conn = self.conn.conn.borrow_mut();
415407

@@ -448,15 +440,57 @@ impl<'conn> Statement<'conn> {
448440
}
449441
};
450442

451-
Ok(CopyOutReader {
452-
info: CopyInfo {
453-
conn: conn,
454-
format: Format::from_u16(format as u16),
455-
column_formats: column_formats.iter().map(|&f| Format::from_u16(f)).collect(),
456-
},
457-
buf: Cursor::new(vec![]),
458-
finished: false,
459-
})
443+
let mut info = CopyInfo {
444+
conn: conn,
445+
format: Format::from_u16(format as u16),
446+
column_formats: column_formats.iter().map(|&f| Format::from_u16(f)).collect(),
447+
};
448+
449+
let count;
450+
loop {
451+
match try!(info.conn.read_message()) {
452+
BCopyData { data } => {
453+
let mut data = &data[..];
454+
while !data.is_empty() {
455+
match w.write_with_info(data, &info) {
456+
Ok(n) => data = &data[n..],
457+
Err(e) => {
458+
loop {
459+
match try!(info.conn.read_message()) {
460+
ReadyForQuery { .. } => return Err(Error::IoError(e)),
461+
_ => {}
462+
}
463+
}
464+
}
465+
}
466+
}
467+
}
468+
BCopyDone => {},
469+
CommandComplete { tag } => {
470+
count = util::parse_update_count(tag);
471+
break;
472+
}
473+
ErrorResponse { fields } => {
474+
loop {
475+
match try!(info.conn.read_message()) {
476+
ReadyForQuery { .. } => return DbError::new(fields),
477+
_ => {}
478+
}
479+
}
480+
}
481+
_ => {
482+
loop {
483+
match try!(info.conn.read_message()) {
484+
ReadyForQuery { .. } => return Err(Error::IoError(bad_response())),
485+
_ => {}
486+
}
487+
}
488+
}
489+
}
490+
}
491+
492+
try!(info.conn.wait_for_ready());
493+
Ok(count)
460494
}
461495

462496
/// Consumes the statement, clearing it from the Postgres session.
@@ -539,6 +573,20 @@ impl<R: Read> ReadWithInfo for R {
539573
}
540574
}
541575

576+
/// Like `Write` except that a `CopyInfo` object is provided as well.
577+
///
578+
/// All types that implement `Write` also implement this trait.
579+
pub trait WriteWithInfo {
580+
/// Like `Write::write`.
581+
fn write_with_info(&mut self, buf: &[u8], info: &CopyInfo) -> io::Result<usize>;
582+
}
583+
584+
impl<W: Write> WriteWithInfo for W {
585+
fn write_with_info(&mut self, buf: &[u8], _: &CopyInfo) -> io::Result<usize> {
586+
self.write(buf)
587+
}
588+
}
589+
542590
impl Column {
543591
/// The name of the column.
544592
pub fn name(&self) -> &str {
@@ -568,95 +616,3 @@ impl Format {
568616
}
569617
}
570618
}
571-
572-
/// A `Read`er for data from `COPY TO STDOUT` queries.
573-
///
574-
/// # Warning
575-
///
576-
/// The underlying connection may not be used while a `CopyOutReader` exists.
577-
/// Any attempt to do so will panic.
578-
pub struct CopyOutReader<'a> {
579-
info: CopyInfo<'a>,
580-
buf: Cursor<Vec<u8>>,
581-
finished: bool,
582-
}
583-
584-
impl<'a> Drop for CopyOutReader<'a> {
585-
fn drop(&mut self) {
586-
let _ = self.finish_inner();
587-
}
588-
}
589-
590-
impl<'a> CopyOutReader<'a> {
591-
/// Returns the `CopyInfo` for the current operation.
592-
pub fn info(&self) -> &CopyInfo {
593-
&self.info
594-
}
595-
596-
/// Consumes the `CopyOutReader`, throwing away any unread data.
597-
///
598-
/// Functionally equivalent to `CopyOutReader`'s `Drop` implementation,
599-
/// except that it returns any error encountered to the caller.
600-
pub fn finish(mut self) -> Result<()> {
601-
self.finish_inner()
602-
}
603-
604-
fn finish_inner(&mut self) -> Result<()> {
605-
while !self.finished {
606-
let pos = self.buf.get_ref().len() as u64;
607-
self.buf.set_position(pos);
608-
try!(self.ensure_filled());
609-
}
610-
Ok(())
611-
}
612-
613-
fn ensure_filled(&mut self) -> Result<()> {
614-
if self.finished || self.buf.position() != self.buf.get_ref().len() as u64 {
615-
return Ok(());
616-
}
617-
618-
match try!(self.info.conn.read_message()) {
619-
BCopyData { data } => self.buf = Cursor::new(data),
620-
BCopyDone => {
621-
self.finished = true;
622-
match try!(self.info.conn.read_message()) {
623-
CommandComplete { .. } => {}
624-
_ => {
625-
self.info.conn.desynchronized = true;
626-
return Err(Error::IoError(bad_response()));
627-
}
628-
}
629-
try!(self.info.conn.wait_for_ready());
630-
}
631-
ErrorResponse { fields } => {
632-
self.finished = true;
633-
try!(self.info.conn.wait_for_ready());
634-
return DbError::new(fields);
635-
}
636-
_ => {
637-
self.info.conn.desynchronized = true;
638-
return Err(Error::IoError(bad_response()));
639-
}
640-
}
641-
642-
Ok(())
643-
}
644-
}
645-
646-
impl<'a> Read for CopyOutReader<'a> {
647-
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
648-
try!(self.ensure_filled());
649-
self.buf.read(buf)
650-
}
651-
}
652-
653-
impl<'a> BufRead for CopyOutReader<'a> {
654-
fn fill_buf(&mut self) -> io::Result<&[u8]> {
655-
try!(self.ensure_filled());
656-
self.buf.fill_buf()
657-
}
658-
659-
fn consume(&mut self, amt: usize) {
660-
self.buf.consume(amt)
661-
}
662-
}

tests/test.rs

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -778,26 +778,10 @@ fn test_copy_out() {
778778
CREATE TEMPORARY TABLE foo (id INT);
779779
INSERT INTO foo (id) VALUES (0), (1), (2), (3)"));
780780
let stmt = or_panic!(conn.prepare("COPY (SELECT id FROM foo ORDER BY id) TO STDOUT"));
781-
let mut reader = or_panic!(stmt.copy_out(&[]));
782-
let mut out = vec![];
783-
or_panic!(reader.read_to_end(&mut out));
784-
assert_eq!(out, b"0\n1\n2\n3\n");
785-
drop(reader);
786-
or_panic!(conn.batch_execute("SELECT 1"));
787-
}
788-
789-
#[test]
790-
fn test_copy_out_partial_read() {
791-
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
792-
or_panic!(conn.batch_execute("
793-
CREATE TEMPORARY TABLE foo (id INT);
794-
INSERT INTO foo (id) VALUES (0), (1), (2), (3)"));
795-
let stmt = or_panic!(conn.prepare("COPY (SELECT id FROM foo ORDER BY id) TO STDOUT"));
796-
let mut reader = or_panic!(stmt.copy_out(&[]));
797-
let mut out = vec![];
798-
or_panic!(reader.by_ref().take(5).read_to_end(&mut out));
799-
assert_eq!(out, b"0\n1\n2");
800-
drop(reader);
781+
let mut buf = vec![];
782+
let count = or_panic!(stmt.copy_out(&[], &mut buf));
783+
assert_eq!(count, 4);
784+
assert_eq!(buf, b"0\n1\n2\n3\n");
801785
or_panic!(conn.batch_execute("SELECT 1"));
802786
}
803787

0 commit comments

Comments
 (0)