Skip to content

Commit f1c4d82

Browse files
committed
Write tests for the sqlite executor logic
Signed-off-by: Ryan Levick <[email protected]>
1 parent 69e93e5 commit f1c4d82

File tree

4 files changed

+144
-5
lines changed

4 files changed

+144
-5
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/factor-sqlite/src/lib.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,7 @@ impl Factor for SqliteFactor {
8181
get_connection_creator(label).is_some()
8282
})?;
8383

84-
Ok(AppState {
85-
allowed_databases,
86-
get_connection_creator,
87-
})
84+
Ok(AppState::new(allowed_databases, get_connection_creator))
8885
}
8986

9087
fn prepare<T: spin_factors::RuntimeFactors>(
@@ -158,6 +155,17 @@ pub struct AppState {
158155
}
159156

160157
impl AppState {
158+
/// Create a new `AppState`
159+
pub fn new(
160+
allowed_databases: HashMap<String, Arc<HashSet<String>>>,
161+
get_connection_creator: host::ConnectionCreatorGetter,
162+
) -> Self {
163+
Self {
164+
allowed_databases,
165+
get_connection_creator,
166+
}
167+
}
168+
161169
/// Get a connection for a given database label.
162170
///
163171
/// Returns `None` if there is no connection creator for the given label.

crates/trigger/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,9 @@ terminal = { path = "../terminal" }
4545
tokio = { version = "1.23", features = ["fs", "rt"] }
4646
tracing = { workspace = true }
4747

48+
[dev-dependencies]
49+
spin-world = { path = "../world" }
50+
tempfile = "3.12"
51+
4852
[lints]
4953
workspace = true

crates/trigger/src/cli/sqlite_statements.rs

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,145 @@ where
7878
let Some(sqlite) = configured_app.app_state::<SqliteFactor>().ok() else {
7979
return Ok(());
8080
};
81-
self.execute(&sqlite).await?;
81+
self.execute(sqlite).await?;
8282
Ok(())
8383
}
8484
}
8585

8686
/// Parses a @{file:label} sqlite statement
8787
fn parse_file_and_label(config: &str) -> anyhow::Result<(&str, &str)> {
8888
let config = config.trim();
89+
if config.is_empty() {
90+
anyhow::bail!("database configuration is empty in the '@{config}' sqlite statement");
91+
}
8992
let (file, label) = match config.split_once(':') {
9093
Some((_, label)) if label.trim().is_empty() => {
9194
anyhow::bail!("database label is empty in the '@{config}' sqlite statement")
9295
}
96+
Some((file, _)) if file.trim().is_empty() => {
97+
anyhow::bail!("file path is empty in the '@{config}' sqlite statement")
98+
}
9399
Some((file, label)) => (file.trim(), label.trim()),
94100
None => (config, "default"),
95101
};
96102
Ok((file, label))
97103
}
104+
105+
#[cfg(test)]
106+
mod tests {
107+
use std::sync::Arc;
108+
use std::{collections::VecDeque, sync::mpsc::Sender};
109+
110+
use spin_core::async_trait;
111+
use spin_factor_sqlite::{Connection, ConnectionCreator};
112+
use spin_world::v2::sqlite as v2;
113+
use tempfile::NamedTempFile;
114+
115+
use super::*;
116+
117+
#[test]
118+
fn test_parse_file_and_label() {
119+
assert_eq!(
120+
parse_file_and_label("file:label").unwrap(),
121+
("file", "label")
122+
);
123+
assert!(parse_file_and_label("file:").is_err());
124+
assert_eq!(parse_file_and_label("file").unwrap(), ("file", "default"));
125+
assert!(parse_file_and_label(":label").is_err());
126+
assert!(parse_file_and_label("").is_err());
127+
}
128+
129+
#[tokio::test]
130+
async fn test_execute() {
131+
let sqlite_file = NamedTempFile::new().unwrap();
132+
std::fs::write(&sqlite_file, "select 2;").unwrap();
133+
134+
let hook = SqlStatementExecutorHook::new(vec![
135+
"SELECT 1;".to_string(),
136+
format!("@{path}:label", path = sqlite_file.path().display()),
137+
]);
138+
let (tx, rx) = std::sync::mpsc::channel();
139+
let creator = Arc::new(MockCreator { tx });
140+
let creator2 = creator.clone();
141+
let get_creator = Arc::new(move |label: &str| {
142+
creator.push(label);
143+
Some(creator2.clone() as _)
144+
});
145+
let sqlite = spin_factor_sqlite::AppState::new(Default::default(), get_creator);
146+
let result = hook.execute(&sqlite).await;
147+
assert!(result.is_ok());
148+
149+
let mut expected: VecDeque<Action> = vec![
150+
Action::CreateConnection("default".to_string()),
151+
Action::Query("SELECT 1;".to_string()),
152+
Action::CreateConnection("label".to_string()),
153+
Action::Execute("select 2;".to_string()),
154+
]
155+
.into_iter()
156+
.collect();
157+
while let Ok(action) = rx.try_recv() {
158+
assert_eq!(action, expected.pop_front().unwrap(), "unexpected action");
159+
}
160+
161+
assert!(
162+
expected.is_empty(),
163+
"Expected actions were never seen: {:?}",
164+
expected
165+
);
166+
}
167+
168+
struct MockCreator {
169+
tx: Sender<Action>,
170+
}
171+
172+
impl MockCreator {
173+
fn push(&self, label: &str) {
174+
self.tx
175+
.send(Action::CreateConnection(label.to_string()))
176+
.unwrap();
177+
}
178+
}
179+
180+
#[async_trait]
181+
impl ConnectionCreator for MockCreator {
182+
async fn create_connection(&self) -> Result<Box<dyn Connection + 'static>, v2::Error> {
183+
Ok(Box::new(MockConnection {
184+
tx: self.tx.clone(),
185+
}))
186+
}
187+
}
188+
189+
struct MockConnection {
190+
tx: Sender<Action>,
191+
}
192+
193+
#[async_trait]
194+
impl Connection for MockConnection {
195+
async fn query(
196+
&self,
197+
query: &str,
198+
parameters: Vec<v2::Value>,
199+
) -> Result<v2::QueryResult, v2::Error> {
200+
self.tx.send(Action::Query(query.to_string())).unwrap();
201+
let _ = parameters;
202+
Ok(v2::QueryResult {
203+
columns: Vec::new(),
204+
rows: Vec::new(),
205+
})
206+
}
207+
208+
async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
209+
self.tx
210+
.send(Action::Execute(statements.to_string()))
211+
.unwrap();
212+
Ok(())
213+
}
214+
}
215+
216+
#[derive(Debug, PartialEq)]
217+
enum Action {
218+
CreateConnection(String),
219+
Query(String),
220+
Execute(String),
221+
}
222+
}

0 commit comments

Comments
 (0)