Skip to content

Commit 0310a7b

Browse files
committed
handle errors in initial postgres COPY statements
1 parent e7663f9 commit 0310a7b

File tree

2 files changed

+92
-22
lines changed

2 files changed

+92
-22
lines changed

sqlx-core/src/postgres/copy.rs

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -116,19 +116,29 @@ pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
116116
}
117117

118118
impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
119-
async fn begin(mut conn: C, statement: &str) -> Result<Self> {
119+
async fn begin(conn: C, statement: &str) -> Result<Self> {
120+
let mut conn = Self::start_copy(conn, statement).await?;
121+
match conn.stream.recv_expect(MessageFormat::CopyInResponse).await {
122+
Ok(response) => Ok(PgCopyIn {
123+
conn: Some(conn),
124+
response,
125+
}),
126+
Err(e) => {
127+
conn.stream
128+
.send(CopyFail::new("failed to start COPY"))
129+
.await?;
130+
conn.stream
131+
.recv_expect(MessageFormat::ReadyForQuery)
132+
.await?;
133+
Err(e)
134+
}
135+
}
136+
}
137+
138+
async fn start_copy(mut conn: C, statement: &str) -> Result<C> {
120139
conn.wait_until_ready().await?;
121140
conn.stream.send(Query(statement)).await?;
122-
123-
let response: CopyResponse = conn
124-
.stream
125-
.recv_expect(MessageFormat::CopyInResponse)
126-
.await?;
127-
128-
Ok(PgCopyIn {
129-
conn: Some(conn),
130-
response,
131-
})
141+
Ok(conn)
132142
}
133143

134144
/// Returns `true` if Postgres is expecting data in text or CSV format.
@@ -252,18 +262,15 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
252262
"fail_with: expected ErrorResponse, got: {:?}",
253263
msg.format
254264
)),
255-
Err(Error::Database(e)) => {
256-
match e.code() {
257-
Some(Cow::Borrowed("57014")) => {
258-
// postgres abort received error code
259-
conn.stream
260-
.recv_expect(MessageFormat::ReadyForQuery)
261-
.await?;
262-
Ok(())
263-
}
264-
_ => Err(Error::Database(e)),
265+
Err(Error::Database(e)) => match e.code() {
266+
Some(Cow::Borrowed("57014")) => {
267+
conn.stream
268+
.recv_expect(MessageFormat::ReadyForQuery)
269+
.await?;
270+
Ok(())
265271
}
266-
}
272+
_ => Err(Error::Database(e)),
273+
},
267274
Err(e) => Err(e),
268275
}
269276
}

tests/postgres/postgres.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,3 +1743,66 @@ async fn test_postgres_bytea_hex_deserialization_errors() -> anyhow::Result<()>
17431743
}
17441744
Ok(())
17451745
}
1746+
1747+
async fn test_copy_in_error_case(query: &str, expected_error: &str) -> anyhow::Result<()> {
1748+
let mut conn = new::<Postgres>().await?;
1749+
conn.execute("CREATE TEMPORARY TABLE IF NOT EXISTS invalid_copy_target (id int4)")
1750+
.await?;
1751+
1752+
// Try the COPY operation
1753+
match conn.copy_in_raw(query).await {
1754+
Ok(_) => anyhow::bail!("expected error"),
1755+
Err(e) => assert!(
1756+
e.to_string().contains(expected_error),
1757+
"expected error to contain: {expected_error}, got: {e:?}"
1758+
),
1759+
}
1760+
1761+
// Verify connection is still usable
1762+
let value = sqlx_oldapi::query("select 1 + 1")
1763+
.try_map(|row: PgRow| row.try_get::<i32, _>(0))
1764+
.fetch_one(&mut conn)
1765+
.await?;
1766+
1767+
assert_eq!(2i32, value);
1768+
1769+
Ok(())
1770+
}
1771+
1772+
#[sqlx_macros::test]
1773+
async fn it_can_recover_from_copy_in_to_missing_table() -> anyhow::Result<()> {
1774+
test_copy_in_error_case(
1775+
r#"
1776+
COPY nonexistent_table (id) FROM STDIN WITH (FORMAT CSV, HEADER);
1777+
"#,
1778+
"does not exist",
1779+
)
1780+
.await
1781+
}
1782+
1783+
#[sqlx_macros::test]
1784+
async fn it_can_recover_from_copy_in_empty_query() -> anyhow::Result<()> {
1785+
test_copy_in_error_case("", "EmptyQuery").await
1786+
}
1787+
1788+
#[sqlx_macros::test]
1789+
async fn it_can_recover_from_copy_in_syntax_error() -> anyhow::Result<()> {
1790+
test_copy_in_error_case(
1791+
r#"
1792+
COPY FROM STDIN WITH (FORMAT CSV);
1793+
"#,
1794+
"syntax error",
1795+
)
1796+
.await
1797+
}
1798+
1799+
#[sqlx_macros::test]
1800+
async fn it_can_recover_from_copy_in_invalid_params() -> anyhow::Result<()> {
1801+
test_copy_in_error_case(
1802+
r#"
1803+
COPY invalid_copy_target FROM STDIN WITH (FORMAT CSV, INVALID_PARAM true);
1804+
"#,
1805+
"invalid_param",
1806+
)
1807+
.await
1808+
}

0 commit comments

Comments
 (0)