Skip to content

Commit 750ad04

Browse files
committed
Rework how connection creation is done
Signed-off-by: Ryan Levick <[email protected]>
1 parent e236a7f commit 750ad04

File tree

5 files changed

+93
-98
lines changed

5 files changed

+93
-98
lines changed

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/factor-sqlite/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ tracing = { workspace = true }
2121

2222
[dev-dependencies]
2323
spin-factors-test = { path = "../factors-test" }
24-
spin-sqlite = { path = "../sqlite" }
2524
tokio = { version = "1", features = ["macros", "rt"] }
2625

2726
[lints]

crates/factor-sqlite/src/host.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::collections::HashSet;
1+
use std::collections::{HashMap, HashSet};
22
use std::sync::Arc;
33

44
use async_trait::async_trait;
@@ -14,35 +14,28 @@ use crate::{Connection, ConnectionCreator};
1414

1515
pub struct InstanceState {
1616
allowed_databases: Arc<HashSet<String>>,
17+
/// A resource table of connections.
1718
connections: table::Table<Box<dyn Connection>>,
18-
get_connection_creator: ConnectionCreatorGetter,
19+
/// A map from database label to connection creators.
20+
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
1921
}
2022

21-
impl InstanceState {
22-
pub fn allowed_databases(&self) -> &HashSet<String> {
23-
&self.allowed_databases
24-
}
25-
}
26-
27-
/// A function that takes a database label and returns a connection creator, if one exists.
28-
pub type ConnectionCreatorGetter =
29-
Arc<dyn Fn(&str) -> Option<Arc<dyn ConnectionCreator>> + Send + Sync>;
30-
3123
impl InstanceState {
3224
/// Create a new `InstanceState`
3325
///
3426
/// Takes the list of allowed databases, and a function for getting a connection creator given a database label.
3527
pub fn new(
3628
allowed_databases: Arc<HashSet<String>>,
37-
get_connection_creator: ConnectionCreatorGetter,
29+
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
3830
) -> Self {
3931
Self {
4032
allowed_databases,
4133
connections: table::Table::new(256),
42-
get_connection_creator,
34+
connection_creators,
4335
}
4436
}
4537

38+
/// Get a connection for a given database label.
4639
fn get_connection(
4740
&self,
4841
connection: Resource<v2::Connection>,
@@ -52,6 +45,11 @@ impl InstanceState {
5245
.map(|conn| conn.as_ref())
5346
.ok_or(v2::Error::InvalidConnection)
5447
}
48+
49+
/// Get the set of allowed databases.
50+
pub fn allowed_databases(&self) -> &HashSet<String> {
51+
&self.allowed_databases
52+
}
5553
}
5654

5755
impl SelfInstanceBuilder for InstanceState {}
@@ -69,7 +67,9 @@ impl v2::HostConnection for InstanceState {
6967
if !self.allowed_databases.contains(&database) {
7068
return Err(v2::Error::AccessDenied);
7169
}
72-
let conn = (self.get_connection_creator)(&database)
70+
let conn = self
71+
.connection_creators
72+
.get(&database)
7373
.ok_or(v2::Error::NoSuchDatabase)?
7474
.create_connection(&database)
7575
.await?;

crates/factor-sqlite/src/lib.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,12 @@ impl Factor for SqliteFactor {
6464
))
6565
})
6666
.collect::<anyhow::Result<HashMap<_, _>>>()?;
67-
let get_connection_creator: host::ConnectionCreatorGetter =
68-
Arc::new(move |label| connection_creators.get(label).cloned());
6967

7068
ensure_allowed_databases_are_configured(&allowed_databases, |label| {
71-
get_connection_creator(label).is_some()
69+
connection_creators.contains_key(label)
7270
})?;
7371

74-
Ok(AppState::new(allowed_databases, get_connection_creator))
72+
Ok(AppState::new(allowed_databases, connection_creators))
7573
}
7674

7775
fn prepare<T: spin_factors::RuntimeFactors>(
@@ -84,10 +82,9 @@ impl Factor for SqliteFactor {
8482
.get(ctx.app_component().id())
8583
.cloned()
8684
.unwrap_or_default();
87-
let get_connection_creator = ctx.app_state().get_connection_creator.clone();
8885
Ok(InstanceState::new(
8986
allowed_databases,
90-
get_connection_creator,
87+
ctx.app_state().connection_creators.clone(),
9188
))
9289
}
9390
}
@@ -132,19 +129,19 @@ pub const ALLOWED_DATABASES_KEY: MetadataKey<Vec<String>> = MetadataKey::new("da
132129
pub struct AppState {
133130
/// A map from component id to a set of allowed database labels.
134131
allowed_databases: HashMap<String, Arc<HashSet<String>>>,
135-
/// A function for mapping from database name to a connection creator.
136-
get_connection_creator: host::ConnectionCreatorGetter,
132+
/// A mapping from database label to a connection creator.
133+
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
137134
}
138135

139136
impl AppState {
140137
/// Create a new `AppState`
141138
pub fn new(
142139
allowed_databases: HashMap<String, Arc<HashSet<String>>>,
143-
get_connection_creator: host::ConnectionCreatorGetter,
140+
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
144141
) -> Self {
145142
Self {
146143
allowed_databases,
147-
get_connection_creator,
144+
connection_creators,
148145
}
149146
}
150147

@@ -155,7 +152,9 @@ impl AppState {
155152
&self,
156153
label: &str,
157154
) -> Option<Result<Box<dyn Connection>, v2::Error>> {
158-
let connection = (self.get_connection_creator)(label)?
155+
let connection = self
156+
.connection_creators
157+
.get(label)?
159158
.create_connection(label)
160159
.await;
161160
Some(connection)
Lines changed: 68 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,132 +1,130 @@
1-
use std::{collections::HashSet, sync::Arc};
1+
use std::{
2+
collections::{HashMap, HashSet},
3+
sync::Arc,
4+
};
25

3-
use spin_factor_sqlite::SqliteFactor;
6+
use spin_factor_sqlite::{RuntimeConfig, SqliteFactor};
47
use spin_factors::{
5-
anyhow::{self, bail, Context},
6-
runtime_config::toml::TomlKeyTracker,
7-
Factor, FactorRuntimeConfigSource, RuntimeConfigSourceFinalizer, RuntimeFactors,
8+
anyhow::{self, bail, Context as _},
9+
RuntimeFactors,
810
};
911
use spin_factors_test::{toml, TestEnvironment};
10-
use spin_sqlite::RuntimeConfigResolver;
11-
use spin_world::async_trait;
12+
use spin_world::{async_trait, v2::sqlite as v2};
13+
use v2::HostConnection as _;
1214

1315
#[derive(RuntimeFactors)]
1416
struct TestFactors {
1517
sqlite: SqliteFactor,
1618
}
1719

1820
#[tokio::test]
19-
async fn sqlite_works() -> anyhow::Result<()> {
21+
async fn errors_when_non_configured_database_used() -> anyhow::Result<()> {
2022
let factors = TestFactors {
2123
sqlite: SqliteFactor::new(),
2224
};
2325
let env = TestEnvironment::new(factors).extend_manifest(toml! {
2426
[component.test-component]
2527
source = "does-not-exist.wasm"
26-
sqlite_databases = ["default"]
28+
sqlite_databases = ["foo"]
2729
});
28-
let state = env.build_instance_state().await?;
30+
let Err(err) = env.build_instance_state().await else {
31+
bail!("Expected build_instance_state to error but it did not");
32+
};
2933

30-
assert_eq!(
31-
state.sqlite.allowed_databases(),
32-
&["default".into()].into_iter().collect::<HashSet<_>>()
33-
);
34+
assert!(err
35+
.to_string()
36+
.contains("One or more components use SQLite databases which are not defined."));
3437

3538
Ok(())
3639
}
3740

3841
#[tokio::test]
39-
async fn errors_when_non_configured_database_used() -> anyhow::Result<()> {
42+
async fn errors_when_database_not_allowed() -> anyhow::Result<()> {
4043
let factors = TestFactors {
4144
sqlite: SqliteFactor::new(),
4245
};
4346
let env = TestEnvironment::new(factors).extend_manifest(toml! {
4447
[component.test-component]
4548
source = "does-not-exist.wasm"
46-
sqlite_databases = ["foo"]
49+
sqlite_databases = []
4750
});
48-
let Err(err) = env.build_instance_state().await else {
49-
bail!("Expected build_instance_state to error but it did not");
50-
};
51+
let mut state = env
52+
.build_instance_state()
53+
.await
54+
.context("build_instance_state failed")?;
5155

52-
assert!(err
53-
.to_string()
54-
.contains("One or more components use SQLite databases which are not defined."));
56+
assert!(matches!(
57+
state.sqlite.open("foo".into()).await,
58+
Err(spin_world::v2::sqlite::Error::AccessDenied)
59+
));
5560

5661
Ok(())
5762
}
5863

5964
#[tokio::test]
60-
async fn no_error_when_database_is_configured() -> anyhow::Result<()> {
65+
async fn it_works_when_database_is_configured() -> anyhow::Result<()> {
6166
let factors = TestFactors {
6267
sqlite: SqliteFactor::new(),
6368
};
64-
let runtime_config = toml! {
65-
[sqlite_database.foo]
66-
type = "spin"
69+
let mut connection_creators = HashMap::new();
70+
connection_creators.insert("foo".to_owned(), Arc::new(MockConnectionCreator) as _);
71+
let runtime_config = TestFactorsRuntimeConfig {
72+
sqlite: Some(RuntimeConfig {
73+
connection_creators,
74+
}),
6775
};
68-
let sqlite_config = RuntimeConfigResolver::new(None, "/".into());
6976
let env = TestEnvironment::new(factors)
7077
.extend_manifest(toml! {
7178
[component.test-component]
7279
source = "does-not-exist.wasm"
7380
sqlite_databases = ["foo"]
7481
})
75-
.runtime_config(TomlRuntimeSource::new(&runtime_config, sqlite_config))?;
76-
env.build_instance_state()
82+
.runtime_config(runtime_config)?;
83+
84+
let mut state = env
85+
.build_instance_state()
7786
.await
7887
.context("build_instance_state failed")?;
79-
Ok(())
80-
}
81-
82-
struct TomlRuntimeSource<'a> {
83-
table: TomlKeyTracker<'a>,
84-
runtime_config_resolver: RuntimeConfigResolver,
85-
}
86-
87-
impl<'a> TomlRuntimeSource<'a> {
88-
fn new(table: &'a toml::Table, runtime_config_resolver: RuntimeConfigResolver) -> Self {
89-
Self {
90-
table: TomlKeyTracker::new(table),
91-
runtime_config_resolver,
92-
}
93-
}
94-
}
9588

96-
impl FactorRuntimeConfigSource<SqliteFactor> for TomlRuntimeSource<'_> {
97-
fn get_runtime_config(
98-
&mut self,
99-
) -> anyhow::Result<Option<<SqliteFactor as Factor>::RuntimeConfig>> {
100-
self.runtime_config_resolver.resolve_from_toml(&self.table)
101-
}
102-
}
89+
assert_eq!(
90+
state.sqlite.allowed_databases(),
91+
&["foo".into()].into_iter().collect::<HashSet<_>>()
92+
);
10393

104-
impl RuntimeConfigSourceFinalizer for TomlRuntimeSource<'_> {
105-
fn finalize(&mut self) -> anyhow::Result<()> {
106-
self.table.validate_all_keys_used()?;
107-
Ok(())
108-
}
94+
assert!(state.sqlite.open("foo".into()).await.is_ok());
95+
Ok(())
10996
}
11097

111-
impl TryFrom<TomlRuntimeSource<'_>> for TestFactorsRuntimeConfig {
112-
type Error = anyhow::Error;
98+
/// A connection creator that returns a mock connection.
99+
struct MockConnectionCreator;
113100

114-
fn try_from(value: TomlRuntimeSource<'_>) -> Result<Self, Self::Error> {
115-
Self::from_source(value)
101+
#[async_trait]
102+
impl spin_factor_sqlite::ConnectionCreator for MockConnectionCreator {
103+
async fn create_connection(
104+
&self,
105+
label: &str,
106+
) -> Result<Box<dyn spin_factor_sqlite::Connection + 'static>, v2::Error> {
107+
let _ = label;
108+
Ok(Box::new(MockConnection))
116109
}
117110
}
118111

119-
/// A connection creator that always returns an error.
120-
struct InvalidConnectionCreator;
112+
/// A mock connection that always errors.
113+
struct MockConnection;
121114

122115
#[async_trait]
123-
impl spin_factor_sqlite::ConnectionCreator for InvalidConnectionCreator {
124-
async fn create_connection(
116+
impl spin_factor_sqlite::Connection for MockConnection {
117+
async fn query(
125118
&self,
126-
label: &str,
127-
) -> Result<Box<dyn spin_factor_sqlite::Connection + 'static>, spin_world::v2::sqlite::Error>
128-
{
129-
let _ = label;
130-
Err(spin_world::v2::sqlite::Error::InvalidConnection)
119+
query: &str,
120+
parameters: Vec<v2::Value>,
121+
) -> Result<v2::QueryResult, v2::Error> {
122+
let _ = (query, parameters);
123+
Err(v2::Error::Io("Mock connection".into()))
124+
}
125+
126+
async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
127+
let _ = statements;
128+
bail!("Mock connection")
131129
}
132130
}

0 commit comments

Comments
 (0)