diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..ff9aa3bb --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "tests/sqllogictest-sqlite"] + path = tests/sqllogictest-sqlite + url = https://github.com/risinglightdb/sqllogictest-sqlite diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index a0c24610..5e8c34ca 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -429,6 +429,10 @@ impl Runner { self.validator = validator; } + pub fn with_hash_threshold(&mut self, hash_threshold: usize) { + self.hash_threshold = hash_threshold; + } + pub async fn apply_record(&mut self, record: Record) -> RecordOutput { match record { Record::Statement { conditions, .. } if self.should_skip(&conditions) => { @@ -496,15 +500,30 @@ impl Runner { } }; + let mut value_sort = false; match sort_mode.as_ref().or(self.sort_mode.as_ref()) { None | Some(SortMode::NoSort) => {} Some(SortMode::RowSort) => { rows.sort_unstable(); } - Some(SortMode::ValueSort) => todo!("value sort"), + Some(SortMode::ValueSort) => { + rows = rows + .iter() + .flat_map(|row| row.iter()) + .map(|s| vec![s.to_owned()]) + .collect(); + rows.sort_unstable(); + value_sort = true; + } }; - if self.hash_threshold > 0 && rows.len() > self.hash_threshold { + let num_values = if value_sort { + rows.len() + } else { + rows.len() * types.len() + }; + + if self.hash_threshold > 0 && num_values > self.hash_threshold { let mut md5 = md5::Context::new(); for line in &rows { for value in line { @@ -688,17 +707,27 @@ impl Runner { } // We compare normalized results. Whitespace characters are ignored. - let normalized_rows = rows - .into_iter() - .map(|strs| strs.iter().map(normalize_string).join(" ")) - .collect_vec(); let expected_results = expected_results.iter().map(normalize_string).collect_vec(); - if !(self.validator)(&normalized_rows, &expected_results) { + + let actual_results = + if types.len() > 1 && rows.len() * types.len() == expected_results.len() { + // value-wise mode + rows.into_iter() + .flat_map(|strs| strs.iter().map(normalize_string).collect_vec()) + .collect_vec() + } else { + // row-wise mode + rows.into_iter() + .map(|strs| strs.iter().map(normalize_string).join(" ")) + .collect_vec() + }; + + if !(self.validator)(&actual_results, &expected_results) { return Err(TestErrorKind::QueryResultMismatch { sql, expected: expected_results.join("\n"), - actual: normalized_rows.join("\n"), + actual: actual_results.join("\n"), } .at(loc)); } diff --git a/tests/Cargo.toml b/tests/Cargo.toml index de111c48..3f9c63b9 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -6,8 +6,14 @@ publish = false [dependencies] sqllogictest = { path = "../sqllogictest" } +rusqlite = { version = "0.28", features = ["bundled"] } [[test]] name = "harness" path = "./harness.rs" -harness = false \ No newline at end of file +harness = false + +[[test]] +name = "sqlite" +path = "./sqlite.rs" +harness = false diff --git a/tests/sqlite.rs b/tests/sqlite.rs new file mode 100644 index 00000000..76b81529 --- /dev/null +++ b/tests/sqlite.rs @@ -0,0 +1,87 @@ +use rusqlite::{types::ValueRef, Connection, Error}; + +use sqllogictest::{harness, ColumnType, DBOutput, Runner, DB}; + +fn hash_threshold(filename: &str) -> usize { + match filename { + "sqlite/select1.test" => 8, + "sqlite/select4.test" => 8, + "sqlite/select5.test" => 8, + _ => 0, + } +} + +fn main() { + let paths = harness::glob("sqllogictest-sqlite/test/**/select*.test").unwrap(); + let mut tests = vec![]; + for entry in paths { + let path = entry.unwrap(); + let filename = path.to_str().unwrap().to_string(); + tests.push(harness::Trial::test(filename.clone(), move || { + let mut tester = Runner::new(db_fn()); + tester.with_hash_threshold(hash_threshold(&filename)); + tester.run_file(path)?; + Ok(()) + })); + } + harness::run(&harness::Arguments::from_args(), tests).exit(); +} + +struct ConnectionWrapper(Connection); + +fn db_fn() -> ConnectionWrapper { + let c = Connection::open_in_memory().unwrap(); + ConnectionWrapper(c) +} + +fn value_to_string(v: ValueRef) -> String { + match v { + ValueRef::Null => "NULL".to_string(), + ValueRef::Integer(i) => i.to_string(), + ValueRef::Real(r) => r.to_string(), + ValueRef::Text(s) => std::str::from_utf8(s).unwrap().to_string(), + ValueRef::Blob(_) => todo!(), + } +} + +impl DB for ConnectionWrapper { + type Error = Error; + + fn run(&mut self, sql: &str) -> Result { + let mut output = vec![]; + + let is_query_sql = { + let lower_sql = sql.trim_start().to_ascii_lowercase(); + lower_sql.starts_with("select") + || lower_sql.starts_with("values") + || lower_sql.starts_with("show") + || lower_sql.starts_with("with") + || lower_sql.starts_with("describe") + }; + + if is_query_sql { + let mut stmt = self.0.prepare(sql)?; + let column_count = stmt.column_count(); + let mut rows = stmt.query([])?; + while let Some(row) = rows.next()? { + let mut row_output = vec![]; + for i in 0..column_count { + let row = row.get_ref(i)?; + row_output.push(value_to_string(row)); + } + output.push(row_output); + } + Ok(DBOutput::Rows { + types: vec![ColumnType::Any; column_count], + rows: output, + }) + } else { + let cnt = self.0.execute(sql, [])?; + Ok(DBOutput::StatementComplete(cnt as u64)) + } + } + + fn engine_name(&self) -> &str { + "sqlite" + } +} diff --git a/tests/sqllogictest-sqlite b/tests/sqllogictest-sqlite new file mode 160000 index 00000000..4ab49f57 --- /dev/null +++ b/tests/sqllogictest-sqlite @@ -0,0 +1 @@ +Subproject commit 4ab49f571a02e3d21d6f8d4d2cc11eb90d7729c9