Skip to content

Commit 9b9b82a

Browse files
committed
Add copy_in to Statement
1 parent b5d9a38 commit 9b9b82a

File tree

2 files changed

+93
-1
lines changed

2 files changed

+93
-1
lines changed

src/lib.rs

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ use std::mem;
7272
use std::slice;
7373
use std::result;
7474
use std::vec;
75-
use byteorder::{WriteBytesExt, BigEndian};
7675
#[cfg(feature = "unix_socket")]
7776
use std::path::PathBuf;
7877

@@ -1598,6 +1597,86 @@ impl<'conn> Statement<'conn> {
15981597
})
15991598
}
16001599

1600+
/// Executes a `COPY FROM STDIN` statement, returning the number of rows
1601+
/// added.
1602+
///
1603+
/// The contents of the provided `Read`er are passed to the Postgres server
1604+
/// verbatim; it is the caller's responsibility to ensure the data is in
1605+
/// the proper format. See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html)
1606+
/// for details.
1607+
///
1608+
/// If the statement is not a `COPY FROM STDIN` statement, this method will
1609+
/// return an error though the statement will still be executed.
1610+
///
1611+
/// # Examples
1612+
///
1613+
/// ```rust,no_run
1614+
/// # use postgres::{Connection, SslMode};
1615+
/// # let conn = Connection::connect("", &SslMode::None).unwrap();
1616+
/// conn.batch_execute("CREATE TABLE people (id INT PRIMARY KEY, name VARCHAR)").unwrap();
1617+
/// let stmt = conn.prepare("COPY people FROM STDIN").unwrap();
1618+
/// stmt.copy_in(&[], &mut "1\tjohn\n2\tjane\n".as_bytes()).unwrap();
1619+
/// ```
1620+
pub fn copy_in<R: Read>(&self, params: &[&ToSql], r: &mut R) -> Result<u64> {
1621+
try!(self.inner_execute("", 0, params));
1622+
let mut conn = self.conn.conn.borrow_mut();
1623+
1624+
match try!(conn.read_message()) {
1625+
CopyInResponse { .. } => {}
1626+
_ => {
1627+
loop {
1628+
match try!(conn.read_message()) {
1629+
ReadyForQuery { .. } => {
1630+
return Err(Error::IoError(std_io::Error::new(
1631+
std_io::ErrorKind::InvalidInput,
1632+
"called `copy_in` on a non-`COPY FROM STDIN` statement")));
1633+
}
1634+
_ => {}
1635+
}
1636+
}
1637+
}
1638+
}
1639+
1640+
let mut buf = vec![];
1641+
loop {
1642+
match std::io::copy(&mut r.take(16 * 1024), &mut buf) {
1643+
Ok(0) => break,
1644+
Ok(len) => {
1645+
try_desync!(conn, conn.stream.write_message(
1646+
&CopyData {
1647+
data: &buf[..len as usize],
1648+
}));
1649+
buf.clear();
1650+
}
1651+
Err(err) => {
1652+
// FIXME better to return the error directly
1653+
try_desync!(conn, conn.stream.write_message(
1654+
&CopyFail {
1655+
message: &err.to_string(),
1656+
}));
1657+
break;
1658+
}
1659+
}
1660+
}
1661+
1662+
try!(conn.write_messages(&[CopyDone, Sync]));
1663+
1664+
let num = match try!(conn.read_message()) {
1665+
CommandComplete { tag } => util::parse_update_count(tag),
1666+
ErrorResponse { fields } => {
1667+
try!(conn.wait_for_ready());
1668+
return DbError::new(fields);
1669+
}
1670+
_ => {
1671+
conn.desynchronized = true;
1672+
return Err(Error::IoError(bad_response()));
1673+
}
1674+
};
1675+
1676+
try!(conn.wait_for_ready());
1677+
Ok(num)
1678+
}
1679+
16011680
/// Consumes the statement, clearing it from the Postgres session.
16021681
///
16031682
/// If this statement was created via the `prepare_cached` method, `finish`

tests/test.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,19 @@ fn test_batch_execute_copy_from_err() {
765765
}
766766
}
767767

768+
#[test]
769+
fn test_copy() {
770+
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
771+
or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[]));
772+
let stmt = or_panic!(conn.prepare("COPY foo (id) FROM STDIN"));
773+
let mut data = &b"1\n2\n3\n5\n8\n"[..];
774+
assert_eq!(5, or_panic!(stmt.copy_in(&[], &mut data)));
775+
let stmt = or_panic!(conn.prepare("SELECT id FROM foo ORDER BY id"));
776+
assert_eq!(vec![1i32, 2, 3, 5, 8],
777+
stmt.query(&[]).unwrap().iter().map(|r| r.get(0)).collect::<Vec<i32>>());
778+
}
779+
780+
768781
#[test]
769782
// Just make sure the impls don't infinite loop
770783
fn test_generic_connection() {

0 commit comments

Comments
 (0)