Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 97 additions & 2 deletions libsql-server/src/namespace/configurator/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,95 @@ async fn run_periodic_compactions(logger: Arc<ReplicationLogger>) -> anyhow::Res
}
}

/// Splits `text` into uppercase keyword tokens, skipping the contents of
/// single- or double-quoted SQL string literals.
///
/// Whitespace and the punctuation characters `(){}[];,` act as token
/// separators and are not emitted as tokens themselves, so keywords such
/// as `END` inside a quoted literal never surface as tokens.
///
/// NOTE(review): SQL comments (`--`, `/* */`) are not recognized, so a
/// keyword inside a comment is still tokenized; dumps are machine-generated,
/// which keeps this acceptable — confirm against the dump producer.
fn tokenize_sql_keywords(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current_token = String::new();
    let mut in_string_literal = false;
    let mut string_delimiter = '\0';

    for ch in text.chars() {
        match ch {
            '\'' | '"' => {
                if !in_string_literal {
                    // Flush any pending token so e.g. `X'lit'END` yields
                    // ["X", "END"] instead of the merged token "XEND".
                    if !current_token.is_empty() {
                        tokens.push(current_token.to_uppercase());
                        current_token.clear();
                    }
                    in_string_literal = true;
                    string_delimiter = ch;
                } else if ch == string_delimiter {
                    // Only the delimiter that opened the literal closes it;
                    // the other quote kind is ordinary literal content.
                    in_string_literal = false;
                }
            }
            c if c.is_whitespace() || "(){}[];,".contains(c) => {
                if in_string_literal {
                    continue;
                }
                if !current_token.is_empty() {
                    tokens.push(current_token.to_uppercase());
                    current_token.clear();
                }
            }
            // Regular characters accumulate into the current token unless
            // we are inside a string literal, whose contents are dropped.
            _ => {
                if !in_string_literal {
                    current_token.push(ch);
                }
            }
        }
    }

    // An unterminated literal swallows the trailing characters; otherwise
    // flush the final token.
    if !current_token.is_empty() && !in_string_literal {
        tokens.push(current_token.to_uppercase());
    }

    tokens
}

/// Heuristically decides whether `sql` (one or more accumulated dump lines
/// ending in `;`) forms a complete statement: every trigger-style `BEGIN`
/// and every `CASE` expression must have a matching `END`.
///
/// `BEGIN` starts a transaction — and therefore needs no `END` — when it is
/// followed by `TRANSACTION`/`IMMEDIATE`/`EXCLUSIVE`/`DEFERRED`, or when it
/// is the final token (a bare `BEGIN;`, which SQLite accepts since both the
/// behavior keyword and `TRANSACTION` are optional).
///
/// Returns `false` for unbalanced input, including a stray `END` that has
/// no opener.
fn is_complete_sql_statement(sql: &str) -> bool {
    let tokens = tokenize_sql_keywords(sql);
    let mut begin_end_depth = 0;
    let mut case_depth = 0;

    for (i, token) in tokens.iter().enumerate() {
        match token.as_str() {
            "CASE" => {
                case_depth += 1;
            }
            "BEGIN" => {
                let next_token = tokens.get(i + 1).map(|s| s.as_str());
                // `None` covers the bare `BEGIN;` transaction form: the
                // trailing `;` is a separator, so `BEGIN` ends the token
                // stream. Previously this was miscounted as an open block.
                let is_transaction_begin = matches!(
                    next_token,
                    None | Some("TRANSACTION")
                        | Some("IMMEDIATE")
                        | Some("EXCLUSIVE")
                        | Some("DEFERRED")
                );

                if !is_transaction_begin {
                    begin_end_depth += 1;
                }
            }
            "END" => {
                if case_depth > 0 {
                    // `END` closes the innermost open `CASE` first.
                    case_depth -= 1;
                } else {
                    // NOTE(review): `END IF`/`END LOOP`/`END WHILE` are not
                    // SQLite constructs; the lookahead is kept as defensive
                    // handling for foreign dialects appearing in dumps.
                    let is_control_flow_end = tokens
                        .get(i + 1)
                        .map(|next| matches!(next.as_str(), "IF" | "LOOP" | "WHILE"))
                        .unwrap_or(false);

                    if !is_control_flow_end {
                        begin_end_depth -= 1;
                    }
                }
            }
            _ => {}
        }

        // More `END`s than openers can never balance out later.
        if begin_end_depth < 0 {
            return false;
        }
    }

    begin_end_depth == 0 && case_depth == 0
}

async fn load_dump<S>(dump: S, conn: PrimaryConnection) -> crate::Result<(), LoadDumpError>
where
S: Stream<Item = std::io::Result<Bytes>> + Unpin,
Expand All @@ -311,12 +400,11 @@ where
curr.clear();
continue;
}
// FIXME: it's a well-known bug that a comment ending with a semicolon is handled incorrectly by the current dump processing code
let statement_end = trimmed.ends_with(';');

// we want to concat original(non-trimmed) lines as trimming will join all them in one
// single-line statement which is incorrect if comments in the end are present
line.push_str(&curr);
let statement_end = trimmed.ends_with(';') && is_complete_sql_statement(&line);
curr.clear();

// This is a hack to ignore the libsql_wasm_func_table table because it is already created
Expand Down Expand Up @@ -374,6 +462,13 @@ where
}
tracing::debug!("loaded {} lines from dump", line_id);

if !line.trim().is_empty() {
return Err(LoadDumpError::InvalidSqlInput(format!(
"Incomplete SQL statement at end of dump: {}",
line.trim()
)));
}

if !conn.is_autocommit().await.unwrap() {
tokio::task::spawn_blocking({
let conn = conn.clone();
Expand Down
183 changes: 183 additions & 0 deletions libsql-server/tests/namespaces/dumps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,186 @@ fn load_dump_with_invalid_sql() {

sim.run().unwrap();
}

/// A dump containing a trigger must load correctly: the semicolon inside
/// the `BEGIN ... END;` trigger body must not be treated as the end of the
/// `CREATE TRIGGER` statement.
#[test]
fn load_dump_with_trigger() {
    const DUMP: &str = r#"
BEGIN TRANSACTION;
CREATE TABLE test (x);
CREATE TRIGGER simple_trigger
AFTER INSERT ON test
BEGIN
INSERT INTO test VALUES (999);
END;
INSERT INTO test VALUES (1);
COMMIT;"#;

    let tmp_dir = tempdir().unwrap();
    let dump_dir = tmp_dir.path().to_path_buf();
    std::fs::write(dump_dir.join("dump.sql"), DUMP).unwrap();

    let mut sim = Builder::new()
        .simulation_duration(Duration::from_secs(1000))
        .build();
    make_primary(&mut sim, dump_dir.clone());

    sim.client("client", async move {
        let http = Client::new();

        let response = http
            .post(
                "http://primary:9090/v1/namespaces/debug_test/create",
                json!({ "dump_url": format!("file:{}", dump_dir.join("dump.sql").display())}),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let db = Database::open_remote_with_connector(
            "http://debug_test.primary:8080",
            "",
            TurmoilConnector,
        )?;
        let conn = db.connect()?;

        // One row from the dump's INSERT plus the row (999) added by the
        // trigger firing on that INSERT.
        let mut rows = conn.query("SELECT COUNT(*) FROM test", ()).await?;
        let row = rows.next().await?.unwrap();
        assert_eq!(row.get::<i64>(0)?, 2);

        Ok(())
    });

    sim.run().unwrap();
}

/// A trigger whose body contains a CASE expression: the `END` closing the
/// CASE must not be mistaken for the `END` closing the trigger body.
#[test]
fn load_dump_with_case_trigger() {
    const DUMP: &str = r#"
BEGIN TRANSACTION;
CREATE TABLE test (id INTEGER, rate REAL DEFAULT 0.0);
CREATE TRIGGER case_trigger
AFTER INSERT ON test
BEGIN
UPDATE test
SET rate =
CASE
WHEN NEW.id = 1
THEN 0.1
ELSE 0.0
END
WHERE id = NEW.id;
END;

INSERT INTO test (id) VALUES (1);
COMMIT;"#;

    let tmp_dir = tempdir().unwrap();
    let dump_dir = tmp_dir.path().to_path_buf();
    std::fs::write(dump_dir.join("dump.sql"), DUMP).unwrap();

    let mut sim = Builder::new()
        .simulation_duration(Duration::from_secs(1000))
        .build();
    make_primary(&mut sim, dump_dir.clone());

    sim.client("client", async move {
        let http = Client::new();

        let response = http
            .post(
                "http://primary:9090/v1/namespaces/case_test/create",
                json!({ "dump_url": format!("file:{}", dump_dir.join("dump.sql").display())}),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let db = Database::open_remote_with_connector(
            "http://case_test.primary:8080",
            "",
            TurmoilConnector,
        )?;
        let conn = db.connect()?;

        // The trigger fired on the INSERT and set rate = 0.1 for id 1.
        let mut rows = conn.query("SELECT id, rate FROM test", ()).await?;
        let row = rows.next().await?.unwrap();
        assert_eq!(row.get::<i64>(0)?, 1);
        assert!((row.get::<f64>(1)? - 0.1).abs() < 0.001);

        Ok(())
    });

    sim.run().unwrap();
}

/// Nested CASE expressions inside a trigger body: each inner `END` must be
/// paired with its own CASE so the trigger's closing `END;` is still found
/// and the statement is sent to the server as a single unit.
#[test]
fn load_dump_with_nested_case() {
    const DUMP: &str = r#"
BEGIN TRANSACTION;
CREATE TABLE orders (id INTEGER, amount REAL, status TEXT);
CREATE TRIGGER nested_trigger
AFTER UPDATE ON orders
BEGIN
UPDATE orders
SET amount =
CASE
WHEN NEW.status = 'completed'
THEN
CASE
WHEN OLD.id = 1
THEN OLD.amount * 0.9
ELSE OLD.amount * 0.8
END
ELSE OLD.amount
END
WHERE id = NEW.id;
END;

INSERT INTO orders (id, amount, status) VALUES (1, 100.0, 'pending');
COMMIT;"#;

    let tmp_dir = tempdir().unwrap();
    let dump_dir = tmp_dir.path().to_path_buf();
    std::fs::write(dump_dir.join("dump.sql"), DUMP).unwrap();

    let mut sim = Builder::new()
        .simulation_duration(Duration::from_secs(1000))
        .build();
    make_primary(&mut sim, dump_dir.clone());

    sim.client("client", async move {
        let http = Client::new();

        let response = http
            .post(
                "http://primary:9090/v1/namespaces/nested_test/create",
                json!({ "dump_url": format!("file:{}", dump_dir.join("dump.sql").display())}),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let db = Database::open_remote_with_connector(
            "http://nested_test.primary:8080",
            "",
            TurmoilConnector,
        )?;
        let conn = db.connect()?;

        // Completing order 1 routes through the inner CASE: 100.0 * 0.9.
        conn.execute("UPDATE orders SET status = 'completed' WHERE id = 1", ())
            .await?;
        let mut rows = conn
            .query("SELECT amount FROM orders WHERE id = 1", ())
            .await?;
        let row = rows.next().await?.unwrap();
        assert!((row.get::<f64>(0)? - 90.0).abs() < 0.001);

        Ok(())
    });

    sim.run().unwrap();
}
Loading