diff --git a/libsql/src/local/rows.rs b/libsql/src/local/rows.rs index 4d4e622c75..f76e15a003 100644 --- a/libsql/src/local/rows.rs +++ b/libsql/src/local/rows.rs @@ -46,8 +46,14 @@ impl Rows { err_msg = errors::error_from_handle(self.stmt.conn.raw); } match err { - libsql_sys::ffi::SQLITE_OK => Ok(None), - libsql_sys::ffi::SQLITE_DONE => Ok(None), + libsql_sys::ffi::SQLITE_OK => { + self.stmt.reset(); + Ok(None) + } + libsql_sys::ffi::SQLITE_DONE => { + self.stmt.reset(); + Ok(None) + } libsql_sys::ffi::SQLITE_ROW => Ok(Some(Row { stmt: self.stmt.clone(), })), @@ -82,6 +88,7 @@ impl AsRef for Rows { } } + pub struct RowsFuture { pub(crate) conn: Connection, pub(crate) sql: String, diff --git a/libsql/src/local/statement.rs b/libsql/src/local/statement.rs index c31e751734..d4eeb3c450 100644 --- a/libsql/src/local/statement.rs +++ b/libsql/src/local/statement.rs @@ -53,6 +53,7 @@ impl Statement { pub fn run(&self, params: &Params) -> Result<()> { self.bind(params); let err = self.inner.step(); + self.inner.reset(); match err { crate::ffi::SQLITE_DONE => Ok(()), crate::ffi::SQLITE_ROW => Ok(()), @@ -124,6 +125,7 @@ impl Statement { pub fn execute(&self, params: &Params) -> Result { self.bind(params); let err = self.inner.step(); + self.inner.reset(); match err { crate::ffi::SQLITE_DONE => Ok(self.conn.changes()), crate::ffi::SQLITE_ROW => Err(Error::ExecuteReturnedRows), diff --git a/libsql/tests/integration_tests.rs b/libsql/tests/integration_tests.rs index 57addab948..32bf95bd53 100644 --- a/libsql/tests/integration_tests.rs +++ b/libsql/tests/integration_tests.rs @@ -854,3 +854,36 @@ fn assert_sqlite_error(res: Result, code: i32) { } } } + +#[tokio::test] +async fn test_prepared_statement_reset() { + // Test for issue #2135 + let db = Database::open(":memory:").unwrap(); + let conn = db.connect().unwrap(); + + conn.execute("CREATE TABLE domain (name TEXT)", ()) + .await + .unwrap(); + + let stmt = conn + .prepare("INSERT INTO domain VALUES (?1)") + .await + .unwrap(); + + let domains = ["example.com", "example.org", "example.net"]; + for domain in domains { + stmt.execute([domain]).await.unwrap(); + } + + let mut rows = conn + .query("SELECT name FROM domain ORDER BY name", ()) + .await + .unwrap(); + let mut results = Vec::new(); + while let Some(row) = rows.next().await.unwrap() { + let name: String = row.get(0).unwrap(); + results.push(name); + } + + assert_eq!(results, vec!["example.com", "example.net", "example.org"]); +}