diff --git a/libsql-server/src/namespace/configurator/helpers.rs b/libsql-server/src/namespace/configurator/helpers.rs index 810b84b76d..599320783d 100644 --- a/libsql-server/src/namespace/configurator/helpers.rs +++ b/libsql-server/src/namespace/configurator/helpers.rs @@ -7,10 +7,13 @@ use anyhow::Context as _; use bottomless::replicator::Options; use bytes::Bytes; use enclose::enclose; +use fallible_iterator::FallibleIterator; use futures::Stream; use libsql_sys::EncryptionConfig; use rusqlite::hooks::{AuthAction, AuthContext, Authorization}; -use tokio::io::AsyncBufReadExt as _; +use sqlite3_parser::ast::{Cmd, Stmt}; +use sqlite3_parser::lexer::sql::{Parser, ParserError}; +use tokio::io::AsyncReadExt; use tokio::task::JoinSet; use tokio_util::io::StreamReader; @@ -33,9 +36,6 @@ use crate::{StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT use super::{BaseNamespaceConfig, PrimaryConfig}; -const WASM_TABLE_CREATE: &str = - "CREATE TABLE libsql_wasm_func_table (name text PRIMARY KEY, body text) WITHOUT ROWID;"; - #[tracing::instrument(skip_all)] pub(super) async fn make_primary_connection_maker( primary_config: &PrimaryConfig, @@ -295,84 +295,89 @@ where S: Stream> + Unpin, { let mut reader = tokio::io::BufReader::new(StreamReader::new(dump)); - let mut curr = String::new(); - let mut line = String::new(); + let mut dump_content = String::new(); + reader + .read_to_string(&mut dump_content) + .await + .map_err(|e| LoadDumpError::Internal(format!("Failed to read dump content: {}", e)))?; + + if dump_content.to_lowercase().contains("attach") { + return Err(LoadDumpError::InvalidSqlInput( + "attach statements are not allowed in dumps".to_string(), + )); + } + + let mut parser = Box::new(Parser::new(dump_content.as_bytes())); let mut skipped_wasm_table = false; let mut n_stmt = 0; - let mut line_id = 0; - while let Ok(n) = reader.read_line(&mut curr).await { - line_id += 1; - if n == 0 { - break; - } - let trimmed = curr.trim(); - if trimmed.is_empty() || trimmed.starts_with("--") { - curr.clear(); - continue; - } - // FIXME: it's well known bug that comment ending with semicolon will be handled incorrectly by currend dump processing code - let statement_end = trimmed.ends_with(';'); - - // we want to concat original(non-trimmed) lines as trimming will join all them in one - // single-line statement which is incorrect if comments in the end are present - line.push_str(&curr); - curr.clear(); - - // This is a hack to ignore the libsql_wasm_func_table table because it is already created - // by the system. - if !skipped_wasm_table && line.trim() == WASM_TABLE_CREATE { - skipped_wasm_table = true; - line.clear(); - continue; - } + loop { + match parser.next() { + Ok(Some(cmd)) => { + n_stmt += 1; + + if !skipped_wasm_table { + if let Cmd::Stmt(Stmt::CreateTable { tbl_name, .. }) = &cmd { + if tbl_name.name.0 == "libsql_wasm_func_table" { + skipped_wasm_table = true; + tracing::debug!("Skipping WASM table creation"); + continue; + } + } + } + + if n_stmt > 2 && conn.is_autocommit().await.unwrap() { + return Err(LoadDumpError::NoTxn); + } - if statement_end { - n_stmt += 1; - // dump must be performd within a txn - if n_stmt > 2 && conn.is_autocommit().await.unwrap() { - return Err(LoadDumpError::NoTxn); + let stmt_sql = cmd.to_string(); + tokio::task::spawn_blocking({ + let conn = conn.clone(); + move || -> crate::Result<(), LoadDumpError> { + conn.with_raw(|conn| { + conn.authorizer(Some(|auth: AuthContext<'_>| match auth.action { + AuthAction::Attach { filename: _ } => Authorization::Deny, + _ => Authorization::Allow, + })); + conn.execute(&stmt_sql, ()) + }) + .map_err(|e| match e { + rusqlite::Error::SqlInputError { + msg, sql, offset, .. + } => LoadDumpError::InvalidSqlInput(format!( + "msg: {}, sql: {}, offset: {}", + msg, sql, offset + )), + e => LoadDumpError::Internal(format!( + "statement: {}, error: {}", + n_stmt, e + )), + })?; + Ok(()) + } + }) + .await??; } + Ok(None) => break, + Err(e) => { + let error_msg = match e { + sqlite3_parser::lexer::sql::Error::ParserError( + ParserError::SyntaxError { token_type, found }, + Some((line, col)), + ) => { + let near_token = found.as_deref().unwrap_or(&token_type); + format!( + "syntax error near '{}' at line {}, column {}", + near_token, line, col + ) + } + _ => format!("parse error: {}", e), + }; - line = tokio::task::spawn_blocking({ - let conn = conn.clone(); - move || -> crate::Result { - conn.with_raw(|conn| { - conn.authorizer(Some(|auth: AuthContext<'_>| match auth.action { - AuthAction::Attach { filename: _ } => Authorization::Deny, - _ => Authorization::Allow, - })); - conn.execute(&line, ()) - }) - .map_err(|e| match e { - rusqlite::Error::SqlInputError { - msg, sql, offset, .. - } => { - let msg = if sql.to_lowercase().contains("attach") { - format!( - "attach statements are not allowed in dumps, msg: {}, sql: {}, offset: {}", - msg, - sql, - offset - ) - } else { - format!("msg: {}, sql: {}, offset: {}", msg, sql, offset) - }; - - LoadDumpError::InvalidSqlInput(msg) - } - e => LoadDumpError::Internal(format!("line: {}, error: {}", line_id, e)), - })?; - Ok(line) - } - }) - .await??; - line.clear(); - } else { - line.push(' '); + return Err(LoadDumpError::InvalidSqlInput(error_msg)); + } } } - tracing::debug!("loaded {} lines from dump", line_id); if !conn.is_autocommit().await.unwrap() { tokio::task::spawn_blocking({ diff --git a/libsql-server/tests/namespaces/dumps.rs b/libsql-server/tests/namespaces/dumps.rs index 1ef1870d12..859130f773 100644 --- a/libsql-server/tests/namespaces/dumps.rs +++ b/libsql-server/tests/namespaces/dumps.rs @@ -425,3 +425,186 @@ fn load_dump_with_invalid_sql() { sim.run().unwrap(); } + +#[test] +fn load_dump_with_trigger() { + const DUMP: &str = r#" + BEGIN TRANSACTION; + CREATE TABLE test (x); + CREATE TRIGGER simple_trigger + AFTER INSERT ON test + BEGIN + INSERT INTO test VALUES (999); + END; + INSERT INTO test VALUES (1); + COMMIT;"#; + + let mut sim = Builder::new() + .simulation_duration(Duration::from_secs(1000)) + .build(); + let tmp = tempdir().unwrap(); + let tmp_path = tmp.path().to_path_buf(); + + std::fs::write(tmp_path.join("dump.sql"), DUMP).unwrap(); + + make_primary(&mut sim, tmp.path().to_path_buf()); + + sim.client("client", async move { + let client = Client::new(); + + let resp = client + .post( + "http://primary:9090/v1/namespaces/debug_test/create", + json!({ "dump_url": format!("file:{}", tmp_path.join("dump.sql").display())}), + ) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let db = Database::open_remote_with_connector( + "http://debug_test.primary:8080", + "", + TurmoilConnector, + )?; + let conn = db.connect()?; + + // Original INSERT: 1, Trigger INSERT: 999 = 2 total rows + let mut rows = conn.query("SELECT COUNT(*) FROM test", ()).await?; + let row = rows.next().await?.unwrap(); + assert_eq!(row.get::(0)?, 2); + + Ok(()) + }); + + sim.run().unwrap(); +} + +#[test] +fn load_dump_with_case_trigger() { + const DUMP: &str = r#" + BEGIN TRANSACTION; + CREATE TABLE test (id INTEGER, rate REAL DEFAULT 0.0); + CREATE TRIGGER case_trigger + AFTER INSERT ON test + BEGIN + UPDATE test + SET rate = + CASE + WHEN NEW.id = 1 + THEN 0.1 + ELSE 0.0 + END + WHERE id = NEW.id; + END; + + INSERT INTO test (id) VALUES (1); + COMMIT;"#; + + let mut sim = Builder::new() + .simulation_duration(Duration::from_secs(1000)) + .build(); + let tmp = tempdir().unwrap(); + let tmp_path = tmp.path().to_path_buf(); + + std::fs::write(tmp_path.join("dump.sql"), DUMP).unwrap(); + + make_primary(&mut sim, tmp.path().to_path_buf()); + + sim.client("client", async move { + let client = Client::new(); + + let resp = client + .post( + "http://primary:9090/v1/namespaces/case_test/create", + json!({ "dump_url": format!("file:{}", tmp_path.join("dump.sql").display())}), + ) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let db = Database::open_remote_with_connector( + "http://case_test.primary:8080", + "", + TurmoilConnector, + )?; + let conn = db.connect()?; + + let mut rows = conn.query("SELECT id, rate FROM test", ()).await?; + let row = rows.next().await?.unwrap(); + assert_eq!(row.get::(0)?, 1); + assert!((row.get::(1)? - 0.1).abs() < 0.001); + + Ok(()) + }); + + sim.run().unwrap(); +} + +#[test] +fn load_dump_with_nested_case() { + const DUMP: &str = r#" + BEGIN TRANSACTION; + CREATE TABLE orders (id INTEGER, amount REAL, status TEXT); + CREATE TRIGGER nested_trigger + AFTER UPDATE ON orders + BEGIN + UPDATE orders + SET amount = + CASE + WHEN NEW.status = 'completed' + THEN + CASE + WHEN OLD.id = 1 + THEN OLD.amount * 0.9 + ELSE OLD.amount * 0.8 + END + ELSE OLD.amount + END + WHERE id = NEW.id; + END; + + INSERT INTO orders (id, amount, status) VALUES (1, 100.0, 'pending'); + COMMIT;"#; + + let mut sim = Builder::new() + .simulation_duration(Duration::from_secs(1000)) + .build(); + let tmp = tempdir().unwrap(); + let tmp_path = tmp.path().to_path_buf(); + + std::fs::write(tmp_path.join("dump.sql"), DUMP).unwrap(); + + make_primary(&mut sim, tmp.path().to_path_buf()); + + sim.client("client", async move { + let client = Client::new(); + + let resp = client + .post( + "http://primary:9090/v1/namespaces/nested_test/create", + json!({ "dump_url": format!("file:{}", tmp_path.join("dump.sql").display())}), + ) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let db = Database::open_remote_with_connector( + "http://nested_test.primary:8080", + "", + TurmoilConnector, + )?; + let conn = db.connect()?; + + conn.execute("UPDATE orders SET status = 'completed' WHERE id = 1", ()) + .await?; + let mut rows = conn + .query("SELECT amount FROM orders WHERE id = 1", ()) + .await?; + let row = rows.next().await?.unwrap(); + assert!((row.get::(0)? - 90.0).abs() < 0.001); + + Ok(()) + }); + + sim.run().unwrap(); +} diff --git a/libsql-server/tests/namespaces/snapshots/tests__namespaces__dumps__load_dump_with_attach_rejected.snap b/libsql-server/tests/namespaces/snapshots/tests__namespaces__dumps__load_dump_with_attach_rejected.snap index 57809be5f1..f4cba37289 100644 --- a/libsql-server/tests/namespaces/snapshots/tests__namespaces__dumps__load_dump_with_attach_rejected.snap +++ b/libsql-server/tests/namespaces/snapshots/tests__namespaces__dumps__load_dump_with_attach_rejected.snap @@ -1,6 +1,5 @@ --- source: libsql-server/tests/namespaces/dumps.rs expression: resp.body_string().await? -snapshot_kind: text --- -{"error":"The passed dump sql is invalid: attach statements are not allowed in dumps, msg: near \"COMMIT\": syntax error, sql: ATTACH foo/bar.sql\n COMMIT;, offset: 28"} +{"error":"The passed dump sql is invalid: attach statements are not allowed in dumps"} diff --git a/libsql-server/tests/namespaces/snapshots/tests__namespaces__dumps__load_dump_with_invalid_sql.snap b/libsql-server/tests/namespaces/snapshots/tests__namespaces__dumps__load_dump_with_invalid_sql.snap index e1f600116f..2935c0eb69 100644 --- a/libsql-server/tests/namespaces/snapshots/tests__namespaces__dumps__load_dump_with_invalid_sql.snap +++ b/libsql-server/tests/namespaces/snapshots/tests__namespaces__dumps__load_dump_with_invalid_sql.snap @@ -3,4 +3,4 @@ source: libsql-server/tests/namespaces/dumps.rs expression: resp.body_string().await? snapshot_kind: text --- -{"error":"The passed dump sql is invalid: msg: near \"COMMIT\": syntax error, sql: SELECT abs(-9223372036854775808) \n COMMIT;, offset: 43"} +{"error":"The passed dump sql is invalid: syntax error near 'COMMIT' at line 7, column 11"}