Skip to content

Commit 6e2435e

Browse files
committed
Write full rows in binary copy
1 parent 8ebe859 commit 6e2435e

File tree

2 files changed

+64
-35
lines changed

2 files changed

+64
-35
lines changed

tokio-postgres-binary-copy/src/lib.rs

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ where
4242
let writer = BinaryCopyWriter {
4343
buf: buf.clone(),
4444
types: types.to_vec(),
45-
idx: 0,
4645
};
4746

4847
BinaryCopyStream {
@@ -85,14 +84,29 @@ where
8584
pub struct BinaryCopyWriter {
8685
buf: Arc<Mutex<BytesMut>>,
8786
types: Vec<Type>,
88-
idx: usize,
8987
}
9088

9189
impl BinaryCopyWriter {
9290
pub async fn write(
9391
&mut self,
94-
value: &(dyn ToSql + Send),
92+
values: &[&(dyn ToSql + Send)],
9593
) -> Result<(), Box<dyn Error + Sync + Send>> {
94+
self.write_raw(values.iter().cloned()).await
95+
}
96+
97+
pub async fn write_raw<'a, I>(&mut self, values: I) -> Result<(), Box<dyn Error + Sync + Send>>
98+
where
99+
I: IntoIterator<Item = &'a (dyn ToSql + Send)>,
100+
I::IntoIter: ExactSizeIterator,
101+
{
102+
let values = values.into_iter();
103+
assert!(
104+
values.len() == self.types.len(),
105+
"expected {} values but got {}",
106+
self.types.len(),
107+
values.len(),
108+
);
109+
96110
future::poll_fn(|_| {
97111
if self.buf.lock().len() > BLOCK_SIZE {
98112
Poll::Pending
@@ -103,20 +117,20 @@ impl BinaryCopyWriter {
103117
.await;
104118

105119
let mut buf = self.buf.lock();
106-
if self.idx == 0 {
107-
buf.reserve(2);
108-
buf.put_i16_be(self.types.len() as i16);
109-
}
110-
let idx = buf.len();
111-
buf.reserve(4);
112-
buf.put_i32_be(0);
113-
let len = match value.to_sql_checked(&self.types[self.idx], &mut buf)? {
114-
IsNull::Yes => -1,
115-
IsNull::No => i32::try_from(buf.len() - idx - 4)?,
116-
};
117-
BigEndian::write_i32(&mut buf[idx..], len);
118120

119-
self.idx = (self.idx + 1) % self.types.len();
121+
buf.reserve(2);
122+
buf.put_i16_be(self.types.len() as i16);
123+
124+
for (value, type_) in values.zip(&self.types) {
125+
let idx = buf.len();
126+
buf.reserve(4);
127+
buf.put_i32_be(0);
128+
let len = match value.to_sql_checked(type_, &mut buf)? {
129+
IsNull::Yes => -1,
130+
IsNull::No => i32::try_from(buf.len() - idx - 4)?,
131+
};
132+
BigEndian::write_i32(&mut buf[idx..], len);
133+
}
120134

121135
Ok(())
122136
}

tokio-postgres-binary-copy/src/test.rs

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,8 @@ async fn write_basic() {
2424

2525
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| {
2626
async move {
27-
w.write(&1i32).await?;
28-
w.write(&"foobar").await?;
29-
30-
w.write(&2i32).await?;
31-
w.write(&None::<&str>).await?;
27+
w.write(&[&1i32, &"foobar"]).await?;
28+
w.write(&[&2i32, &None::<&str>]).await?;
3229

3330
Ok(())
3431
}
@@ -39,7 +36,10 @@ async fn write_basic() {
3936
.await
4037
.unwrap();
4138

42-
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
39+
let rows = client
40+
.query("SELECT id, bar FROM foo ORDER BY id", &[])
41+
.await
42+
.unwrap();
4343
assert_eq!(rows.len(), 2);
4444
assert_eq!(rows[0].get::<_, i32>(0), 1);
4545
assert_eq!(rows[0].get::<_, Option<&str>>(1), Some("foobar"));
@@ -56,18 +56,25 @@ async fn write_many_rows() {
5656
.await
5757
.unwrap();
5858

59-
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| async move {
60-
for i in 0..10_000i32 {
61-
w.write(&i).await?;
62-
w.write(&format!("the value for {}", i)).await?;
63-
}
59+
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| {
60+
async move {
61+
for i in 0..10_000i32 {
62+
w.write(&[&i, &format!("the value for {}", i)]).await?;
63+
}
6464

65-
Ok(())
65+
Ok(())
66+
}
6667
});
6768

68-
client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream).await.unwrap();
69+
client
70+
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
71+
.await
72+
.unwrap();
6973

70-
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
74+
let rows = client
75+
.query("SELECT id, bar FROM foo ORDER BY id", &[])
76+
.await
77+
.unwrap();
7178
for (i, row) in rows.iter().enumerate() {
7279
assert_eq!(row.get::<_, i32>(0), i as i32);
7380
assert_eq!(row.get::<_, &str>(1), format!("the value for {}", i));
@@ -78,22 +85,30 @@ async fn write_many_rows() {
7885
async fn write_big_rows() {
7986
let client = connect().await;
8087

81-
client.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)").await.unwrap();
88+
client
89+
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
90+
.await
91+
.unwrap();
8292

8393
let stream = BinaryCopyStream::new(&[Type::INT4, Type::BYTEA], |mut w| {
8494
async move {
8595
for i in 0..2i32 {
86-
w.write(&i).await.unwrap();
87-
w.write(&vec![i as u8; 128 * 1024]).await.unwrap();
96+
w.write(&[&i, &vec![i as u8; 128 * 1024]]).await?;
8897
}
8998

9099
Ok(())
91100
}
92101
});
93102

94-
client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream).await.unwrap();
103+
client
104+
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
105+
.await
106+
.unwrap();
95107

96-
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
108+
let rows = client
109+
.query("SELECT id, bar FROM foo ORDER BY id", &[])
110+
.await
111+
.unwrap();
97112
for (i, row) in rows.iter().enumerate() {
98113
assert_eq!(row.get::<_, i32>(0), i as i32);
99114
assert_eq!(row.get::<_, &[u8]>(1), &*vec![i as u8; 128 * 1024]);

0 commit comments

Comments
 (0)