Skip to content

Commit 21c66f4

Browse files
authored
Merge pull request #2840 from fermyon/simplify-default-sqlite
Simplify default database resolution in sqlite
2 parents caacf55 + 4d50c8e commit 21c66f4

File tree

11 files changed

+197
-169
lines changed

11 files changed

+197
-169
lines changed

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/factor-sqlite/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ tracing = { workspace = true }
2121

2222
[dev-dependencies]
2323
spin-factors-test = { path = "../factors-test" }
24-
spin-sqlite = { path = "../sqlite" }
2524
tokio = { version = "1", features = ["macros", "rt"] }
2625

2726
[lints]

crates/factor-sqlite/src/host.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::collections::HashSet;
1+
use std::collections::{HashMap, HashSet};
22
use std::sync::Arc;
33

44
use async_trait::async_trait;
@@ -14,35 +14,28 @@ use crate::{Connection, ConnectionCreator};
1414

1515
pub struct InstanceState {
1616
allowed_databases: Arc<HashSet<String>>,
17+
/// A resource table of connections.
1718
connections: table::Table<Box<dyn Connection>>,
18-
get_connection_creator: ConnectionCreatorGetter,
19+
/// A map from database label to connection creators.
20+
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
1921
}
2022

21-
impl InstanceState {
22-
pub fn allowed_databases(&self) -> &HashSet<String> {
23-
&self.allowed_databases
24-
}
25-
}
26-
27-
/// A function that takes a database label and returns a connection creator, if one exists.
28-
pub type ConnectionCreatorGetter =
29-
Arc<dyn Fn(&str) -> Option<Arc<dyn ConnectionCreator>> + Send + Sync>;
30-
3123
impl InstanceState {
3224
/// Create a new `InstanceState`
3325
///
3426
/// Takes the list of allowed databases, and a function for getting a connection creator given a database label.
3527
pub fn new(
3628
allowed_databases: Arc<HashSet<String>>,
37-
get_connection_creator: ConnectionCreatorGetter,
29+
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
3830
) -> Self {
3931
Self {
4032
allowed_databases,
4133
connections: table::Table::new(256),
42-
get_connection_creator,
34+
connection_creators,
4335
}
4436
}
4537

38+
/// Get a connection for a given database label.
4639
fn get_connection(
4740
&self,
4841
connection: Resource<v2::Connection>,
@@ -52,6 +45,11 @@ impl InstanceState {
5245
.map(|conn| conn.as_ref())
5346
.ok_or(v2::Error::InvalidConnection)
5447
}
48+
49+
/// Get the set of allowed databases.
50+
pub fn allowed_databases(&self) -> &HashSet<String> {
51+
&self.allowed_databases
52+
}
5553
}
5654

5755
impl SelfInstanceBuilder for InstanceState {}
@@ -69,7 +67,9 @@ impl v2::HostConnection for InstanceState {
6967
if !self.allowed_databases.contains(&database) {
7068
return Err(v2::Error::AccessDenied);
7169
}
72-
let conn = (self.get_connection_creator)(&database)
70+
let conn = self
71+
.connection_creators
72+
.get(&database)
7373
.ok_or(v2::Error::NoSuchDatabase)?
7474
.create_connection(&database)
7575
.await?;

crates/factor-sqlite/src/lib.rs

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,15 @@ use spin_world::v2::sqlite as v2;
1414

1515
pub use runtime_config::RuntimeConfig;
1616

17+
#[derive(Default)]
1718
pub struct SqliteFactor {
18-
default_label_resolver: Arc<dyn DefaultLabelResolver>,
19+
_priv: (),
1920
}
2021

2122
impl SqliteFactor {
2223
/// Create a new `SqliteFactor`
23-
///
24-
/// Takes a `default_label_resolver` for how to handle when a database label doesn't
25-
/// have a corresponding runtime configuration.
26-
pub fn new(default_label_resolver: impl DefaultLabelResolver + 'static) -> Self {
27-
Self {
28-
default_label_resolver: Arc::new(default_label_resolver),
29-
}
24+
pub fn new() -> Self {
25+
Self { _priv: () }
3026
}
3127
}
3228

@@ -50,8 +46,8 @@ impl Factor for SqliteFactor {
5046
) -> anyhow::Result<Self::AppState> {
5147
let connection_creators = ctx
5248
.take_runtime_config()
53-
.map(|r| r.connection_creators)
54-
.unwrap_or_default();
49+
.unwrap_or_default()
50+
.connection_creators;
5551

5652
let allowed_databases = ctx
5753
.app()
@@ -69,19 +65,12 @@ impl Factor for SqliteFactor {
6965
))
7066
})
7167
.collect::<anyhow::Result<HashMap<_, _>>>()?;
72-
let resolver = self.default_label_resolver.clone();
73-
let get_connection_creator: host::ConnectionCreatorGetter = Arc::new(move |label| {
74-
connection_creators
75-
.get(label)
76-
.cloned()
77-
.or_else(|| resolver.default(label))
78-
});
7968

8069
ensure_allowed_databases_are_configured(&allowed_databases, |label| {
81-
get_connection_creator(label).is_some()
70+
connection_creators.contains_key(label)
8271
})?;
8372

84-
Ok(AppState::new(allowed_databases, get_connection_creator))
73+
Ok(AppState::new(allowed_databases, connection_creators))
8574
}
8675

8776
fn prepare<T: spin_factors::RuntimeFactors>(
@@ -94,10 +83,9 @@ impl Factor for SqliteFactor {
9483
.get(ctx.app_component().id())
9584
.cloned()
9685
.unwrap_or_default();
97-
let get_connection_creator = ctx.app_state().get_connection_creator.clone();
9886
Ok(InstanceState::new(
9987
allowed_databases,
100-
get_connection_creator,
88+
ctx.app_state().connection_creators.clone(),
10189
))
10290
}
10391
}
@@ -138,31 +126,23 @@ fn ensure_allowed_databases_are_configured(
138126
/// Metadata key for a list of allowed databases for a component.
139127
pub const ALLOWED_DATABASES_KEY: MetadataKey<Vec<String>> = MetadataKey::new("databases");
140128

141-
/// Resolves a label to a default connection creator.
142-
pub trait DefaultLabelResolver: Send + Sync {
143-
/// If there is no runtime configuration for a given database label, return a default connection creator.
144-
///
145-
/// If `Option::None` is returned, the database is not allowed.
146-
fn default(&self, label: &str) -> Option<Arc<dyn ConnectionCreator>>;
147-
}
148-
149129
#[derive(Clone)]
150130
pub struct AppState {
151131
/// A map from component id to a set of allowed database labels.
152132
allowed_databases: HashMap<String, Arc<HashSet<String>>>,
153-
/// A function for mapping from database name to a connection creator.
154-
get_connection_creator: host::ConnectionCreatorGetter,
133+
/// A mapping from database label to a connection creator.
134+
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
155135
}
156136

157137
impl AppState {
158138
/// Create a new `AppState`
159139
pub fn new(
160140
allowed_databases: HashMap<String, Arc<HashSet<String>>>,
161-
get_connection_creator: host::ConnectionCreatorGetter,
141+
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
162142
) -> Self {
163143
Self {
164144
allowed_databases,
165-
get_connection_creator,
145+
connection_creators,
166146
}
167147
}
168148

@@ -173,7 +153,9 @@ impl AppState {
173153
&self,
174154
label: &str,
175155
) -> Option<Result<Box<dyn Connection>, v2::Error>> {
176-
let connection = (self.get_connection_creator)(label)?
156+
let connection = self
157+
.connection_creators
158+
.get(label)?
177159
.create_connection(label)
178160
.await;
179161
Some(connection)

crates/factor-sqlite/src/runtime_config.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::ConnectionCreator;
55
/// A runtime configuration for SQLite databases.
66
///
77
/// Maps database labels to connection creators.
8+
#[derive(Default)]
89
pub struct RuntimeConfig {
910
pub connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
1011
}

0 commit comments

Comments
 (0)