Skip to content

Commit 913678e

Browse files
committed
Properly return ioerrors to the caller in copy_in
1 parent 53aafc3 commit 913678e

File tree

2 files changed

+41
-51
lines changed

2 files changed

+41
-51
lines changed

src/lib.rs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,18 +1581,27 @@ impl<'conn> Statement<'conn> {
15811581
Ok(0) => break,
15821582
Ok(len) => {
15831583
try_desync!(conn, conn.stream.write_message(
1584-
&CopyData {
1585-
data: &buf[..len as usize],
1586-
}));
1584+
&CopyData {
1585+
data: &buf[..len as usize],
1586+
}));
15871587
buf.clear();
15881588
}
15891589
Err(err) => {
1590-
// FIXME better to return the error directly
1591-
try_desync!(conn, conn.stream.write_message(
1592-
&CopyFail {
1593-
message: &err.to_string(),
1594-
}));
1595-
break;
1590+
try!(conn.write_messages(&[
1591+
CopyFail {
1592+
message: "",
1593+
},
1594+
CopyDone,
1595+
Sync]));
1596+
match try!(conn.read_message()) {
1597+
ErrorResponse { .. } => { /* expected from the CopyFail */ }
1598+
_ => {
1599+
conn.desynchronized = true;
1600+
return Err(Error::IoError(bad_response()));
1601+
}
1602+
}
1603+
try!(conn.wait_for_ready());
1604+
return Err(Error::IoError(err));
15961605
}
15971606
}
15981607
}

tests/test.rs

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ extern crate openssl;
77
#[cfg(feature = "openssl")]
88
use openssl::ssl::{SslContext, SslMethod};
99
use std::thread;
10+
use std::io;
1011

1112
use postgres::{HandleNotice,
1213
Notification,
@@ -605,48 +606,6 @@ fn test_notifications_next_block() {
605606
}, or_panic!(notifications.next_block()));
606607
}
607608

608-
/*
609-
#[test]
610-
fn test_notifications_next_block_for() {
611-
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
612-
or_panic!(conn.execute("LISTEN test_notifications_next_block_for", &[]));
613-
614-
let _t = thread::spawn(|| {
615-
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
616-
timer::sleep(Duration::milliseconds(500));
617-
or_panic!(conn.execute("NOTIFY test_notifications_next_block_for, 'foo'", &[]));
618-
});
619-
620-
let mut notifications = conn.notifications();
621-
check_notification(Notification {
622-
pid: 0,
623-
channel: "test_notifications_next_block_for".to_string(),
624-
payload: "foo".to_string()
625-
}, or_panic!(notifications.next_block_for(Duration::seconds(2)).unwrap()));
626-
}
627-
628-
#[test]
629-
fn test_notifications_next_block_for_timeout() {
630-
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
631-
or_panic!(conn.execute("LISTEN test_notifications_next_block_for_timeout", &[]));
632-
633-
let _t = thread::spawn(|| {
634-
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
635-
timer::sleep(Duration::seconds(2));
636-
or_panic!(conn.execute("NOTIFY test_notifications_next_block_for_timeout, 'foo'", &[]));
637-
});
638-
639-
let mut notifications = conn.notifications();
640-
match notifications.next_block_for(Duration::milliseconds(500)) {
641-
None => {}
642-
Some(Err(e)) => panic!("Unexpected error {:?}", e),
643-
Some(Ok(_)) => panic!("expected error"),
644-
}
645-
646-
or_panic!(conn.execute("SELECT 1", &[]));
647-
}
648-
*/
649-
650609
#[test]
651610
// This test is pretty sad, but I don't think there's a better way :(
652611
fn test_cancel_query() {
@@ -762,6 +721,28 @@ fn test_batch_execute_copy_from_err() {
762721
}
763722
}
764723

724+
#[test]
725+
fn test_copy_io_error() {
726+
struct ErrorReader;
727+
728+
impl io::Read for ErrorReader {
729+
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
730+
Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "boom"))
731+
}
732+
}
733+
734+
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
735+
or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[]));
736+
let stmt = or_panic!(conn.prepare("COPY foo (id) FROM STDIN"));
737+
match stmt.copy_in(&[], &mut ErrorReader) {
738+
Err(Error::IoError(ref e)) if e.kind() == io::ErrorKind::AddrNotAvailable => {}
739+
Err(err) => panic!("Unexpected error {:?}", err),
740+
_ => panic!("Expected error"),
741+
}
742+
743+
or_panic!(conn.execute("SELECT 1", &[]));
744+
}
745+
765746
#[test]
766747
fn test_copy() {
767748
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));

0 commit comments

Comments
 (0)