Skip to content

Commit 6e99874

Browse files
committed
Add COPY TO STDOUT support.
Closes #51
1 parent 63e278b commit 6e99874

File tree

4 files changed

+237
-5
lines changed

4 files changed

+237
-5
lines changed

src/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,9 @@ impl From<byteorder::Error> for Error {
350350
Error::IoError(From::from(err))
351351
}
352352
}
353+
354+
impl From<Error> for io::Error {
355+
fn from(err: Error) -> io::Error {
356+
io::Error::new(io::ErrorKind::Other, err)
357+
}
358+
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ mod macros;
8989
mod md5;
9090
mod message;
9191
mod priv_io;
92-
mod stmt;
9392
mod url;
9493
mod util;
9594
pub mod error;
9695
pub mod io;
9796
pub mod rows;
97+
pub mod stmt;
9898
pub mod types;
9999

100100
const TYPEINFO_QUERY: &'static str = "t";

src/stmt.rs

Lines changed: 198 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
//! Prepared statements
2+
13
use debug_builders::DebugStruct;
2-
use std::cell::Cell;
4+
use std::cell::{Cell, RefMut};
35
use std::collections::VecDeque;
46
use std::fmt;
5-
use std::io;
7+
use std::io::{self, Cursor, BufRead, Read};
68

79
use error::{Error, DbError};
810
use types::{ReadWithInfo, SessionInfo, Type, ToSql, IsNull};
@@ -12,7 +14,7 @@ use message::WriteMessage;
1214
use util;
1315
use rows::{Rows, LazyRows};
1416
use {read_rows, bad_response, Connection, Transaction, StatementInternals, Result, RowsNew};
15-
use {SessionInfoNew, LazyRowsNew, DbErrorNew, ColumnNew};
17+
use {InnerConnection, SessionInfoNew, LazyRowsNew, DbErrorNew, ColumnNew};
1618

1719
/// A prepared statement.
1820
pub struct Statement<'conn> {
@@ -371,6 +373,84 @@ impl<'conn> Statement<'conn> {
371373
Ok(num)
372374
}
373375

376+
/// Executes a `COPY TO STDOUT` statement, returning a `Read`er of the
377+
/// resulting data.
378+
///
379+
/// See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html)
380+
/// for details on the data format.
381+
///
382+
/// If the statement is not a `COPY TO STDOUT` statement it will still be
383+
/// executed and this method will return an error.
384+
///
385+
/// # Warning
386+
///
387+
/// The underlying connection may not be used while the returned `Read`er
388+
/// exists. Any attempt to do so will panic.
389+
///
390+
/// # Examples
391+
///
392+
/// ```rust,no_run
393+
/// # use std::io::Read;
394+
/// # use postgres::{Connection, SslMode};
395+
/// # let conn = Connection::connect("", &SslMode::None).unwrap();
396+
/// conn.batch_execute("
397+
/// CREATE TABLE people (id INT PRIMARY KEY, name VARCHAR);
398+
/// INSERT INTO people (id, name) VALUES (1, 'john'), (2, 'jane');").unwrap();
399+
/// let stmt = conn.prepare("COPY people TO STDOUT").unwrap();
400+
/// let mut r = stmt.copy_out(&[]).unwrap();
401+
/// let mut buf = vec![];
402+
/// r.read_to_end(&mut buf).unwrap();
403+
/// r.finish().unwrap();
404+
/// assert_eq!(buf, b"1\tjohn\n2\tjane\n");
405+
/// ```
406+
pub fn copy_out<'a>(&'a self, params: &[&ToSql]) -> Result<CopyOutReader<'a>> {
407+
try!(self.inner_execute("", 0, params));
408+
let mut conn = self.conn.conn.borrow_mut();
409+
410+
let (format, column_formats) = match try!(conn.read_message()) {
411+
CopyOutResponse { format, column_formats } => (format, column_formats),
412+
CopyInResponse { .. } => {
413+
try!(conn.write_messages(&[
414+
CopyFail {
415+
message: "",
416+
},
417+
CopyDone,
418+
Sync]));
419+
match try!(conn.read_message()) {
420+
ErrorResponse { .. } => { /* expected from the CopyFail */ }
421+
_ => {
422+
conn.desynchronized = true;
423+
return Err(Error::IoError(bad_response()));
424+
}
425+
}
426+
try!(conn.wait_for_ready());
427+
return Err(Error::IoError(io::Error::new(
428+
io::ErrorKind::InvalidInput,
429+
"called `copy_out` on a non-`COPY TO STDOUT` statement")));
430+
}
431+
_ => {
432+
loop {
433+
match try!(conn.read_message()) {
434+
ReadyForQuery { .. } => {
435+
return Err(Error::IoError(io::Error::new(
436+
io::ErrorKind::InvalidInput,
437+
"called `copy_out` on a non-`COPY TO STDOUT` statement")));
438+
}
439+
_ => {}
440+
}
441+
}
442+
}
443+
};
444+
445+
Ok(CopyOutReader {
446+
conn: conn,
447+
format: Format::from_u16(format as u16),
448+
column_formats: column_formats.iter().map(|&f| Format::from_u16(f)).collect(),
449+
buf: Cursor::new(vec![]),
450+
finished: false,
451+
})
452+
}
453+
374454
/// Consumes the statement, clearing it from the Postgres session.
375455
///
376456
/// If this statement was created via the `prepare_cached` method, `finish`
@@ -425,4 +505,119 @@ impl Column {
425505
}
426506
}
427507

508+
/// The format of a portion of COPY query data.
509+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
510+
pub enum Format {
511+
/// A text based format.
512+
Text,
513+
/// A binary format.
514+
Binary,
515+
}
516+
517+
impl Format {
518+
fn from_u16(value: u16) -> Format {
519+
match value {
520+
0 => Format::Text,
521+
_ => Format::Binary,
522+
}
523+
}
524+
}
525+
526+
/// A `Read`er for data from `COPY TO STDOUT` queries.
527+
///
528+
/// # Warning
529+
///
530+
/// The underlying connection may not be used while a `CopyOutReader` exists.
531+
/// Any calls to the connection with panic.
532+
pub struct CopyOutReader<'a> {
533+
conn: RefMut<'a, InnerConnection>,
534+
format: Format,
535+
column_formats: Vec<Format>,
536+
buf: Cursor<Vec<u8>>,
537+
finished: bool,
538+
}
539+
540+
impl<'a> Drop for CopyOutReader<'a> {
541+
fn drop(&mut self) {
542+
let _ = self.finish_inner();
543+
}
544+
}
545+
546+
impl<'a> CopyOutReader<'a> {
547+
/// Returns the format of the overall data.
548+
pub fn format(&self) -> Format {
549+
self.format
550+
}
551+
552+
/// Returns the format of the individual columns.
553+
pub fn column_formats(&self) -> &[Format] {
554+
&self.column_formats
555+
}
556+
557+
/// Consumes the `CopyOutReader`, throwing away any unread data.
558+
///
559+
/// Functionally equivalent to `CopyOutReader`'s `Drop` implementation,
560+
/// except that it returns any error encountered to the caller.
561+
pub fn finish(mut self) -> Result<()> {
562+
self.finish_inner()
563+
}
564+
565+
fn finish_inner(&mut self) -> Result<()> {
566+
while !self.finished {
567+
let pos = self.buf.get_ref().len() as u64;
568+
self.buf.set_position(pos);
569+
try!(self.ensure_filled());
570+
}
571+
Ok(())
572+
}
573+
574+
fn ensure_filled(&mut self) -> Result<()> {
575+
if self.finished || self.buf.position() != self.buf.get_ref().len() as u64 {
576+
return Ok(());
577+
}
428578

579+
match try!(self.conn.read_message()) {
580+
BCopyData { data } => self.buf = Cursor::new(data),
581+
BCopyDone => {
582+
self.finished = true;
583+
match try!(self.conn.read_message()) {
584+
CommandComplete { .. } => {}
585+
_ => {
586+
self.conn.desynchronized = true;
587+
return Err(Error::IoError(bad_response()));
588+
}
589+
}
590+
try!(self.conn.wait_for_ready());
591+
}
592+
ErrorResponse { fields } => {
593+
self.finished = true;
594+
try!(self.conn.wait_for_ready());
595+
return DbError::new(fields);
596+
}
597+
_ => {
598+
self.conn.desynchronized = true;
599+
return Err(Error::IoError(bad_response()));
600+
}
601+
}
602+
603+
Ok(())
604+
}
605+
}
606+
607+
impl<'a> Read for CopyOutReader<'a> {
608+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
609+
try!(self.ensure_filled());
610+
self.buf.read(buf)
611+
}
612+
}
613+
614+
impl<'a> BufRead for CopyOutReader<'a> {
615+
fn fill_buf(&mut self) -> io::Result<&[u8]> {
616+
try!(self.ensure_filled());
617+
self.buf.fill_buf()
618+
}
619+
620+
fn consume(&mut self, amt: usize) {
621+
self.buf.consume(amt)
622+
}
623+
}

tests/test.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ extern crate openssl;
88
use openssl::ssl::{SslContext, SslMethod};
99
use std::thread;
1010
use std::io;
11+
use std::io::prelude::*;
1112

1213
use postgres::{HandleNotice,
1314
Notification,
@@ -757,7 +758,7 @@ fn test_copy() {
757758
}
758759

759760
#[test]
760-
fn test_copy_out_query() {
761+
fn test_query_copy_out_err() {
761762
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
762763
or_panic!(conn.batch_execute("
763764
CREATE TEMPORARY TABLE foo (id INT);
@@ -770,6 +771,36 @@ fn test_copy_out_query() {
770771
}
771772
}
772773

774+
#[test]
775+
fn test_copy_out() {
776+
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
777+
or_panic!(conn.batch_execute("
778+
CREATE TEMPORARY TABLE foo (id INT);
779+
INSERT INTO foo (id) VALUES (0), (1), (2), (3)"));
780+
let stmt = or_panic!(conn.prepare("COPY (SELECT id FROM foo ORDER BY id) TO STDOUT"));
781+
let mut reader = or_panic!(stmt.copy_out(&[]));
782+
let mut out = vec![];
783+
or_panic!(reader.read_to_end(&mut out));
784+
assert_eq!(out, b"0\n1\n2\n3\n");
785+
drop(reader);
786+
or_panic!(conn.batch_execute("SELECT 1"));
787+
}
788+
789+
#[test]
790+
fn test_copy_out_partial_read() {
791+
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
792+
or_panic!(conn.batch_execute("
793+
CREATE TEMPORARY TABLE foo (id INT);
794+
INSERT INTO foo (id) VALUES (0), (1), (2), (3)"));
795+
let stmt = or_panic!(conn.prepare("COPY (SELECT id FROM foo ORDER BY id) TO STDOUT"));
796+
let mut reader = or_panic!(stmt.copy_out(&[]));
797+
let mut out = vec![];
798+
or_panic!(reader.by_ref().take(5).read_to_end(&mut out));
799+
assert_eq!(out, b"0\n1\n2");
800+
drop(reader);
801+
or_panic!(conn.batch_execute("SELECT 1"));
802+
}
803+
773804
#[test]
774805
// Just make sure the impls don't infinite loop
775806
fn test_generic_connection() {

0 commit comments

Comments
 (0)