Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions bindings/rust/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,30 @@ impl Connection {
connection
}

/// Returns `true` when this connection is safe to return to the pool.
///
/// Recycling is only allowed when nothing else can still observe the
/// connection: the inner handle must exist and be exclusively owned,
/// no dangling transaction may be pending, and the connection must be
/// back in auto-commit mode.
pub(crate) fn can_recycle_into_pool(&self) -> bool {
    let Some(inner) = self.inner.as_ref() else {
        return false;
    };

    let exclusively_owned = Arc::strong_count(inner) == 1;
    // `Ignore` means no transaction was left dangling on this handle.
    let no_pending_tx = self.dangling_tx.load(Ordering::SeqCst) == DropBehavior::Ignore;

    exclusively_owned
        && no_pending_tx
        && self
            .get_inner_connection()
            .map(|conn| conn.get_auto_commit())
            .unwrap_or(false)
}

/// Restores the defaults a freshly-connected handle would have, so a
/// recycled connection starts from a clean slate.
pub(crate) fn reset_for_reuse(&mut self) {
    self.dangling_tx
        .store(DropBehavior::Ignore, Ordering::SeqCst);
    self.transaction_behavior = TransactionBehavior::Deferred;
}

pub(crate) async fn maybe_handle_dangling_tx(&self) -> Result<()> {
match self.dangling_tx.load(Ordering::SeqCst) {
DropBehavior::Rollback => {
Expand Down
257 changes: 256 additions & 1 deletion bindings/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ pub use params::IntoParams;

use std::fmt::Debug;
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::Weak;
use std::task::Poll;

// Re-exports rows
Expand Down Expand Up @@ -133,6 +135,8 @@ pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub type Result<T> = std::result::Result<T, Error>;
pub type EncryptionOpts = turso_sdk_kit::rsapi::EncryptionOpts;

const DEFAULT_MAX_IDLE_CONNECTIONS: usize = 16;

/// A builder for `Database`.
pub struct Builder {
path: String,
Expand All @@ -141,6 +145,7 @@ pub struct Builder {
enable_custom_types: bool,
enable_index_method: bool,
enable_materialized_views: bool,
max_idle_connections: usize,
vfs: Option<String>,
encryption_opts: Option<turso_sdk_kit::rsapi::EncryptionOpts>,
}
Expand All @@ -155,6 +160,7 @@ impl Builder {
enable_custom_types: false,
enable_index_method: false,
enable_materialized_views: false,
max_idle_connections: DEFAULT_MAX_IDLE_CONNECTIONS,
vfs: None,
encryption_opts: None,
}
Expand Down Expand Up @@ -204,6 +210,13 @@ impl Builder {
self.vfs = Some(vfs);
self
}

/// Maximum number of idle connections kept by `Database::connect_pooled`.
pub fn max_idle_connections(mut self, max_idle_connections: usize) -> Self {
self.max_idle_connections = max_idle_connections;
self
}

fn build_features_string(&self) -> Option<String> {
let mut features = Vec::new();
if self.enable_encryption {
Expand Down Expand Up @@ -250,7 +263,10 @@ impl Builder {
.await
.map_err(TursoError::from)?;
}
Ok(Database { inner: db })
Ok(Database {
inner: db,
pool: Arc::new(ConnectionPool::new(self.max_idle_connections)),
})
}
}

Expand All @@ -260,6 +276,7 @@ impl Builder {
#[derive(Clone)]
pub struct Database {
inner: Arc<turso_sdk_kit::rsapi::TursoDatabase>,
pool: Arc<ConnectionPool>,
}

impl Debug for Database {
Expand All @@ -274,6 +291,96 @@ impl Database {
let conn = self.inner.connect()?;
Ok(Connection::create(conn, None))
}

/// Connect to the database using the built-in connection pool.
///
/// Reuses an idle connection when one is available, otherwise opens a
/// fresh one. The connection is offered back to the pool when the
/// returned [`PooledConnection`] is dropped.
pub fn connect_pooled(&self) -> Result<PooledConnection> {
    let conn = if let Some(idle) = self.pool.acquire() {
        idle
    } else {
        self.connect()?
    };

    Ok(PooledConnection {
        conn: Some(conn),
        // Weak handle so a checked-out connection does not keep the
        // pool alive after the database itself is gone.
        pool: Arc::downgrade(&self.pool),
    })
}
}

/// A simple LIFO pool of idle [`Connection`]s shared by clones of one
/// [`Database`].
#[derive(Debug)]
struct ConnectionPool {
    // Upper bound on `idle.len()`; `0` disables recycling entirely.
    max_idle_connections: usize,
    // Idle connections ready for reuse, mutex-guarded because the pool
    // is shared across `Database` clones (and thus across threads).
    idle: Mutex<Vec<Connection>>,
}

impl ConnectionPool {
    /// Creates an empty pool retaining at most `max_idle_connections`
    /// idle connections.
    fn new(max_idle_connections: usize) -> Self {
        Self {
            max_idle_connections,
            idle: Mutex::new(Vec::new()),
        }
    }

    /// Takes an idle connection out of the pool, if one is available.
    fn acquire(&self) -> Option<Connection> {
        let mut idle = self.idle.lock().unwrap();
        idle.pop()
    }

    /// Offers a connection back to the pool.
    ///
    /// Connections that cannot be safely recycled, and connections beyond
    /// the idle capacity, are simply dropped (closed) instead.
    fn release(&self, mut conn: Connection) {
        // Short-circuit: with a zero-capacity pool we never even consult
        // `can_recycle_into_pool`, matching the fast path for capacity 0.
        let recyclable = self.max_idle_connections > 0 && conn.can_recycle_into_pool();
        if !recyclable {
            return;
        }

        conn.reset_for_reuse();

        let mut idle = self.idle.lock().unwrap();
        if idle.len() < self.max_idle_connections {
            idle.push(conn);
        }
    }
}

/// A [`Connection`] checked out from a [`Database`]'s built-in pool.
///
/// Dereferences to [`Connection`]; on drop the connection is offered
/// back to the pool it came from.
pub struct PooledConnection {
    // `None` only after `into_inner` has detached the connection (or
    // transiently inside `drop`).
    conn: Option<Connection>,
    // Weak so checked-out connections do not keep the pool alive after
    // the owning `Database` has been dropped.
    pool: Weak<ConnectionPool>,
}

impl PooledConnection {
    /// Detaches the underlying [`Connection`] from the pool.
    ///
    /// The returned connection is no longer recycled on drop; it closes
    /// normally like any other connection.
    pub fn into_inner(mut self) -> Connection {
        let conn = self.conn.take();
        conn.expect("pooled connection must always contain a connection")
    }
}

impl Deref for PooledConnection {
    type Target = Connection;

    // Borrow the wrapped connection. The `Option` is always `Some` while
    // the guard is alive: only `into_inner` and `drop` take it out, and
    // both consume the value.
    fn deref(&self) -> &Self::Target {
        self.conn
            .as_ref()
            .expect("pooled connection must always contain a connection")
    }
}

impl DerefMut for PooledConnection {
    // Mutable access so `&mut Connection` methods (e.g. transactions)
    // work directly through the guard; same `Some` invariant as `deref`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.conn
            .as_mut()
            .expect("pooled connection must always contain a connection")
    }
}

impl Drop for PooledConnection {
    fn drop(&mut self) {
        // `conn` is `None` when `into_inner` already detached it, and the
        // pool is gone when the owning `Database` was dropped first. In
        // either case the connection (if any) simply closes here instead
        // of being recycled.
        if let Some(conn) = self.conn.take() {
            if let Some(pool) = self.pool.upgrade() {
                pool.release(conn);
            }
        }
    }
}

/// A prepared statement.
Expand Down Expand Up @@ -519,6 +626,12 @@ mod tests {
use super::*;
use tempfile::NamedTempFile;

/// Reads the connection's current `cache_size` PRAGMA value.
async fn cache_size(conn: &Connection) -> Result<i64> {
    let mut rows = conn.query("PRAGMA cache_size", ()).await?;
    rows.next()
        .await?
        .expect("expected PRAGMA cache_size row")
        .get(0)
}

#[tokio::test]
async fn test_database_persistence() -> Result<()> {
let temp_file = NamedTempFile::new().unwrap();
Expand Down Expand Up @@ -801,4 +914,146 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_connect_pooled_reuses_connection() -> Result<()> {
    let db = Builder::new_local(":memory:").build().await?;

    // Mark the connection with a non-default PRAGMA, then return it to
    // the pool by letting it drop at the end of the scope.
    {
        let first = db.connect_pooled()?;
        first.execute("PRAGMA cache_size = 1234", ()).await?;
    }

    // The recycled connection must still carry the PRAGMA value.
    let second = db.connect_pooled()?;
    assert_eq!(cache_size(&second).await?, 1234);

    Ok(())
}

#[tokio::test]
async fn test_connect_pooled_does_not_recycle_with_live_statement() -> Result<()> {
    let db = Builder::new_local(":memory:").build().await?;

    let conn = db.connect_pooled()?;
    conn.execute("PRAGMA cache_size = 3456", ()).await?;
    // Keep a prepared statement alive past the connection drop: the extra
    // strong reference must prevent the connection from being recycled.
    let _stmt = conn.prepare("SELECT 1").await?;
    drop(conn);

    // Since nothing was recycled, this connection is fresh and has the
    // default cache size.
    let conn = db.connect_pooled()?;
    assert_ne!(cache_size(&conn).await?, 3456);

    Ok(())
}

#[tokio::test]
async fn test_connect_pooled_respects_max_idle_connections() -> Result<()> {
    // With a zero-capacity pool every released connection is dropped.
    let db = Builder::new_local(":memory:")
        .max_idle_connections(0)
        .build()
        .await?;

    {
        let marked = db.connect_pooled()?;
        marked.execute("PRAGMA cache_size = 5678", ()).await?;
    }

    // Nothing was retained, so this connection must be brand new.
    let fresh = db.connect_pooled()?;
    assert_ne!(cache_size(&fresh).await?, 5678);

    Ok(())
}

#[tokio::test]
async fn test_connect_pooled_capacity_drops_extra_idle_connections() -> Result<()> {
    let db = Builder::new_local(":memory:")
        .max_idle_connections(1)
        .build()
        .await?;

    // Check out two connections simultaneously and tag each with a
    // distinct PRAGMA value so we can tell them apart later.
    let conn_a = db.connect_pooled()?;
    conn_a.execute("PRAGMA cache_size = 12345", ()).await?;

    let conn_b = db.connect_pooled()?;
    conn_b.execute("PRAGMA cache_size = 54321", ()).await?;

    // Capacity is 1. Drop A first so A is retained, then B should be dropped.
    drop(conn_a);
    drop(conn_b);

    // The single idle slot hands back A.
    let conn_from_pool = db.connect_pooled()?;
    assert_eq!(cache_size(&conn_from_pool).await?, 12345);

    // With capacity=1 and first pooled conn still checked out, this must be fresh.
    let fresh_conn = db.connect_pooled()?;
    let fresh_cache = cache_size(&fresh_conn).await?;
    assert_ne!(fresh_cache, 12345);
    assert_ne!(fresh_cache, 54321);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_connect_pooled_concurrency_smoke() -> Result<()> {
    let db = Builder::new_local(":memory:")
        .max_idle_connections(8)
        .build()
        .await?;

    let task_count = 16;
    let iterations = 50;
    // Release all tasks at once to maximize contention on the pool.
    let barrier = std::sync::Arc::new(tokio::sync::Barrier::new(task_count));

    let handles: Vec<_> = (0..task_count)
        .map(|_| {
            let db = db.clone();
            let barrier = barrier.clone();
            tokio::spawn(async move {
                barrier.wait().await;
                for _ in 0..iterations {
                    let conn = db.connect_pooled()?;
                    let mut rows = conn.query("SELECT 1", ()).await?;
                    let row = rows.next().await?.expect("expected row from SELECT 1");
                    assert_eq!(row.get::<i64>(0)?, 1);
                }
                Ok::<(), Error>(())
            })
        })
        .collect();

    for handle in handles {
        handle.await.expect("task panicked")?;
    }

    Ok(())
}

#[tokio::test]
async fn test_connect_pooled_into_inner_bypasses_pool() -> Result<()> {
    let db = Builder::new_local(":memory:").build().await?;

    let pooled = db.connect_pooled()?;
    pooled.execute("PRAGMA cache_size = 32123", ()).await?;

    // Detaching via `into_inner` means dropping must NOT recycle it.
    drop(pooled.into_inner());

    // The next pooled connection is therefore a fresh one.
    let conn = db.connect_pooled()?;
    assert_ne!(cache_size(&conn).await?, 32123);

    Ok(())
}

#[tokio::test]
async fn test_connect_pooled_supports_mut_connection_methods() -> Result<()> {
    use crate::transaction::TransactionBehavior;

    let db = Builder::new_local(":memory:").build().await?;
    let mut conn = db.connect_pooled()?;

    // `DerefMut` lets `&mut Connection` methods work through the guard.
    conn.set_transaction_behavior(TransactionBehavior::Immediate);
    let tx = conn.transaction().await?;
    tx.rollback().await?;
    // After rollback the connection is back in auto-commit mode.
    assert!(conn.is_autocommit()?);

    Ok(())
}
}
19 changes: 19 additions & 0 deletions bindings/rust/tests/integration_tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
use tokio::fs;
use turso::{Builder, EncryptionOpts, Error, Value};

/// Reads the connection's current `cache_size` PRAGMA value.
async fn cache_size(conn: &turso::Connection) -> i64 {
    let mut rows = conn.query("PRAGMA cache_size", ()).await.unwrap();
    let first = rows.next().await.unwrap();
    first.expect("expected row").get(0).unwrap()
}

#[tokio::test]
async fn test_connect_pooled_integration_reuses_connection() {
    let db = Builder::new_local(":memory:").build().await.unwrap();

    // Tag a pooled connection, then release it back to the pool.
    {
        let marked = db.connect_pooled().unwrap();
        marked
            .execute("PRAGMA cache_size = 4321", ())
            .await
            .unwrap();
    }

    // The recycled connection retains the PRAGMA set before release.
    let reused = db.connect_pooled().unwrap();
    assert_eq!(cache_size(&reused).await, 4321);
}

#[tokio::test]
async fn test_rows_next() {
let builder = Builder::new_local(":memory:");
Expand Down
Loading