diff --git a/bindings/rust/src/connection.rs b/bindings/rust/src/connection.rs index 543f26c797..e87786fa0e 100644 --- a/bindings/rust/src/connection.rs +++ b/bindings/rust/src/connection.rs @@ -82,6 +82,30 @@ impl Connection { connection } + pub(crate) fn can_recycle_into_pool(&self) -> bool { + let Some(inner) = self.inner.as_ref() else { + return false; + }; + + if Arc::strong_count(inner) != 1 { + return false; + } + + if self.dangling_tx.load(Ordering::SeqCst) != DropBehavior::Ignore { + return false; + } + + self.get_inner_connection() + .map(|conn| conn.get_auto_commit()) + .unwrap_or(false) + } + + pub(crate) fn reset_for_reuse(&mut self) { + self.transaction_behavior = TransactionBehavior::Deferred; + self.dangling_tx + .store(DropBehavior::Ignore, Ordering::SeqCst); + } + pub(crate) async fn maybe_handle_dangling_tx(&self) -> Result<()> { match self.dangling_tx.load(Ordering::SeqCst) { DropBehavior::Rollback => { diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index de4b62186a..d667e9ed7a 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -54,8 +54,10 @@ pub use params::IntoParams; use std::fmt::Debug; use std::future::Future; +use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::sync::Mutex; +use std::sync::Weak; use std::task::Poll; // Re-exports rows @@ -133,6 +135,8 @@ pub(crate) type BoxError = Box; pub type Result = std::result::Result; pub type EncryptionOpts = turso_sdk_kit::rsapi::EncryptionOpts; +const DEFAULT_MAX_IDLE_CONNECTIONS: usize = 16; + /// A builder for `Database`. pub struct Builder { path: String, @@ -141,6 +145,7 @@ pub struct Builder { enable_custom_types: bool, enable_index_method: bool, enable_materialized_views: bool, + max_idle_connections: usize, vfs: Option, encryption_opts: Option, } @@ -155,6 +160,7 @@ impl Builder { enable_custom_types: false, enable_index_method: false, enable_materialized_views: false, + max_idle_connections: DEFAULT_MAX_IDLE_CONNECTIONS, vfs: None, encryption_opts: None, } @@ -204,6 +210,13 @@ impl Builder { self.vfs = Some(vfs); self } + + /// Maximum number of idle connections kept by `Database::connect_pooled`. + pub fn max_idle_connections(mut self, max_idle_connections: usize) -> Self { + self.max_idle_connections = max_idle_connections; + self + } + fn build_features_string(&self) -> Option { let mut features = Vec::new(); if self.enable_encryption { @@ -250,7 +263,10 @@ impl Builder { .await .map_err(TursoError::from)?; } - Ok(Database { inner: db }) + Ok(Database { + inner: db, + pool: Arc::new(ConnectionPool::new(self.max_idle_connections)), + }) } } @@ -260,6 +276,7 @@ impl Builder { #[derive(Clone)] pub struct Database { inner: Arc, + pool: Arc, } impl Debug for Database { @@ -274,6 +291,96 @@ impl Database { let conn = self.inner.connect()?; Ok(Connection::create(conn, None)) } + + /// Connect to the database using the built-in connection pool. + pub fn connect_pooled(&self) -> Result { + let conn = match self.pool.acquire() { + Some(conn) => conn, + None => self.connect()?, + }; + + Ok(PooledConnection { + conn: Some(conn), + pool: Arc::downgrade(&self.pool), + }) + } +} + +#[derive(Debug)] +struct ConnectionPool { + max_idle_connections: usize, + idle: Mutex>, +} + +impl ConnectionPool { + fn new(max_idle_connections: usize) -> Self { + Self { + max_idle_connections, + idle: Mutex::new(Vec::new()), + } + } + + fn acquire(&self) -> Option { + self.idle.lock().unwrap().pop() + } + + fn release(&self, mut conn: Connection) { + if self.max_idle_connections == 0 || !conn.can_recycle_into_pool() { + return; + } + + conn.reset_for_reuse(); + + let mut idle = self.idle.lock().unwrap(); + if idle.len() < self.max_idle_connections { + idle.push(conn); + } + } +} + +pub struct PooledConnection { + conn: Option, + pool: Weak, +} + +impl PooledConnection { + pub fn into_inner(mut self) -> Connection { + self.conn + .take() + .expect("pooled connection must always contain a connection") + } +} + +impl Deref for PooledConnection { + type Target = Connection; + + fn deref(&self) -> &Self::Target { + self.conn + .as_ref() + .expect("pooled connection must always contain a connection") + } +} + +impl DerefMut for PooledConnection { + fn deref_mut(&mut self) -> &mut Self::Target { + self.conn + .as_mut() + .expect("pooled connection must always contain a connection") + } +} + +impl Drop for PooledConnection { + fn drop(&mut self) { + let Some(conn) = self.conn.take() else { + return; + }; + + let Some(pool) = self.pool.upgrade() else { + return; + }; + + pool.release(conn); + } } /// A prepared statement. @@ -519,6 +626,12 @@ mod tests { use super::*; use tempfile::NamedTempFile; + async fn cache_size(conn: &Connection) -> Result { + let mut rows = conn.query("PRAGMA cache_size", ()).await?; + let row = rows.next().await?.expect("expected PRAGMA cache_size row"); + row.get(0) + } + #[tokio::test] async fn test_database_persistence() -> Result<()> { let temp_file = NamedTempFile::new().unwrap(); @@ -801,4 +914,146 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_connect_pooled_reuses_connection() -> Result<()> { + let db = Builder::new_local(":memory:").build().await?; + + { + let conn = db.connect_pooled()?; + conn.execute("PRAGMA cache_size = 1234", ()).await?; + } + + let conn = db.connect_pooled()?; + assert_eq!(cache_size(&conn).await?, 1234); + + Ok(()) + } + + #[tokio::test] + async fn test_connect_pooled_does_not_recycle_with_live_statement() -> Result<()> { + let db = Builder::new_local(":memory:").build().await?; + + let conn = db.connect_pooled()?; + conn.execute("PRAGMA cache_size = 3456", ()).await?; + let _stmt = conn.prepare("SELECT 1").await?; + drop(conn); + + let conn = db.connect_pooled()?; + assert_ne!(cache_size(&conn).await?, 3456); + + Ok(()) + } + + #[tokio::test] + async fn test_connect_pooled_respects_max_idle_connections() -> Result<()> { + let db = Builder::new_local(":memory:") + .max_idle_connections(0) + .build() + .await?; + + { + let conn = db.connect_pooled()?; + conn.execute("PRAGMA cache_size = 5678", ()).await?; + } + + let conn = db.connect_pooled()?; + assert_ne!(cache_size(&conn).await?, 5678); + + Ok(()) + } + + #[tokio::test] + async fn test_connect_pooled_capacity_drops_extra_idle_connections() -> Result<()> { + let db = Builder::new_local(":memory:") + .max_idle_connections(1) + .build() + .await?; + + let conn_a = db.connect_pooled()?; + conn_a.execute("PRAGMA cache_size = 12345", ()).await?; + + let conn_b = db.connect_pooled()?; + conn_b.execute("PRAGMA cache_size = 54321", ()).await?; + + // Capacity is 1. Drop A first so A is retained, then B should be dropped. + drop(conn_a); + drop(conn_b); + + let conn_from_pool = db.connect_pooled()?; + assert_eq!(cache_size(&conn_from_pool).await?, 12345); + + // With capacity=1 and first pooled conn still checked out, this must be fresh. + let fresh_conn = db.connect_pooled()?; + let fresh_cache = cache_size(&fresh_conn).await?; + assert_ne!(fresh_cache, 12345); + assert_ne!(fresh_cache, 54321); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn test_connect_pooled_concurrency_smoke() -> Result<()> { + let db = Builder::new_local(":memory:") + .max_idle_connections(8) + .build() + .await?; + + let task_count = 16; + let iterations = 50; + let barrier = std::sync::Arc::new(tokio::sync::Barrier::new(task_count)); + let mut handles = Vec::new(); + + for _ in 0..task_count { + let db = db.clone(); + let barrier = barrier.clone(); + handles.push(tokio::spawn(async move { + barrier.wait().await; + for _ in 0..iterations { + let conn = db.connect_pooled()?; + let mut rows = conn.query("SELECT 1", ()).await?; + let row = rows.next().await?.expect("expected row from SELECT 1"); + assert_eq!(row.get::(0)?, 1); + } + Ok::<(), Error>(()) + })); + } + + for handle in handles { + handle.await.expect("task panicked")?; + } + + Ok(()) + } + + #[tokio::test] + async fn test_connect_pooled_into_inner_bypasses_pool() -> Result<()> { + let db = Builder::new_local(":memory:").build().await?; + + let pooled = db.connect_pooled()?; + pooled.execute("PRAGMA cache_size = 32123", ()).await?; + + let conn = pooled.into_inner(); + drop(conn); + + let conn = db.connect_pooled()?; + assert_ne!(cache_size(&conn).await?, 32123); + + Ok(()) + } + + #[tokio::test] + async fn test_connect_pooled_supports_mut_connection_methods() -> Result<()> { + use crate::transaction::TransactionBehavior; + + let db = Builder::new_local(":memory:").build().await?; + let mut conn = db.connect_pooled()?; + + conn.set_transaction_behavior(TransactionBehavior::Immediate); + let tx = conn.transaction().await?; + tx.rollback().await?; + assert!(conn.is_autocommit()?); + + Ok(()) + } } diff --git a/bindings/rust/src/sync.rs b/bindings/rust/src/sync.rs index 30ff1d2f35..9037a6f6e5 100644 --- a/bindings/rust/src/sync.rs +++ b/bindings/rust/src/sync.rs @@ -901,6 +901,14 @@ mod tests { Ok(result) } + fn is_retryable_parallel_write_error(err: &crate::Error) -> bool { + match err { + crate::Error::Busy(_) | crate::Error::BusySnapshot(_) => true, + crate::Error::Error(msg) => msg.to_ascii_lowercase().contains("schema changed"), + _ => false, + } + } + #[tokio::test] pub async fn test_sync_bootstrap() { let _ = tracing_subscriber::fmt::try_init(); @@ -1406,7 +1414,7 @@ mod tests { .await { Ok(_) => break, - Err(crate::Error::Busy(_)) => { + Err(e) if is_retryable_parallel_write_error(&e) => { tokio::time::sleep(Duration::from_millis(10)).await; continue; } @@ -1423,12 +1431,22 @@ mod tests { // Sequential writes: 3 more large inserts for i in 0..after_cnt { let data = format!("sequential_{i}_{payload}"); - conn.execute( - "INSERT INTO test_data (payload) VALUES (?)", - crate::params::Params::Positional(vec![Value::Text(data)]), - ) - .await - .unwrap(); + loop { + match conn + .execute( + "INSERT INTO test_data (payload) VALUES (?)", + crate::params::Params::Positional(vec![Value::Text(data.clone())]), + ) + .await + { + Ok(_) => break, + Err(e) if is_retryable_parallel_write_error(&e) => { + tokio::time::sleep(Duration::from_millis(10)).await; + continue; + } + Err(e) => panic!("sequential insert failed (row{i}): {e:?}"), + } + } } // Signal sync task to stop and wait for it diff --git a/bindings/rust/tests/integration_tests.rs b/bindings/rust/tests/integration_tests.rs index 6d8ca8c313..b2655e04ba 100644 --- a/bindings/rust/tests/integration_tests.rs +++ b/bindings/rust/tests/integration_tests.rs @@ -1,6 +1,25 @@ use tokio::fs; use turso::{Builder, EncryptionOpts, Error, Value}; +async fn cache_size(conn: &turso::Connection) -> i64 { + let mut rows = conn.query("PRAGMA cache_size", ()).await.unwrap(); + let row = rows.next().await.unwrap().expect("expected row"); + row.get(0).unwrap() +} + +#[tokio::test] +async fn test_connect_pooled_integration_reuses_connection() { + let db = Builder::new_local(":memory:").build().await.unwrap(); + + { + let conn = db.connect_pooled().unwrap(); + conn.execute("PRAGMA cache_size = 4321", ()).await.unwrap(); + } + + let conn = db.connect_pooled().unwrap(); + assert_eq!(cache_size(&conn).await, 4321); +} + #[tokio::test] async fn test_rows_next() { let builder = Builder::new_local(":memory:");