Skip to content

Commit 30bcf5d

Browse files
authored
refactor(sql): Allow multiple drivers at the same time (#1838)
* refactor(sql): Allow multiple drivers at the same time * fmt * remove default feature comment [skip ci] * what was that doing there [skip ci] * disable public methods for now
1 parent 6857993 commit 30bcf5d

File tree

11 files changed

+615
-393
lines changed

11 files changed

+615
-393
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
sql: patch
3+
---
4+
5+
It is now possible to enable multiple SQL backends at the same time. There will be no compile error anymore if no backends are enabled!

.github/workflows/lint-rust.yml

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,7 @@ jobs:
148148
- uses: Swatinem/rust-cache@v2
149149

150150
- name: clippy ${{ matrix.package }}
151-
if: matrix.package != 'tauri-plugin-sql'
152151
run: cargo clippy --package ${{ matrix.package }} --all-targets -- -D warnings
153152

154-
- name: clippy ${{ matrix.package }} mysql
155-
if: matrix.package == 'tauri-plugin-sql'
156-
run: cargo clippy --package ${{ matrix.package }} --all-targets --no-default-features --features mysql -- -D warnings
157-
158-
- name: clippy ${{ matrix.package }} postgres
159-
if: matrix.package == 'tauri-plugin-sql'
160-
run: cargo clippy --package ${{ matrix.package }} --all-targets --no-default-features --features postgres -- -D warnings
153+
- name: clippy ${{ matrix.package }} --all-features
154+
run: cargo clippy --package ${{ matrix.package }} --all-targets --all-features -- -D warnings

.github/workflows/test-rust.yml

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -215,21 +215,9 @@ jobs:
215215
run: cargo +stable install cross --git https://github.com/cross-rs/cross
216216

217217
- name: test ${{ matrix.package }}
218-
if: matrix.package != 'tauri-plugin-sql' && matrix.package != 'tauri-plugin-http'
218+
if: matrix.package != 'tauri-plugin-http'
219219
run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --all-features
220220

221221
- name: test ${{ matrix.package }}
222222
if: matrix.package == 'tauri-plugin-http'
223223
run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets
224-
225-
- name: test ${{ matrix.package }} sqlite
226-
if: matrix.package == 'tauri-plugin-sql'
227-
run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --features sqlite
228-
229-
- name: test ${{ matrix.package }} mysql
230-
if: matrix.package == 'tauri-plugin-sql'
231-
run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --features mysql
232-
233-
- name: test ${{ matrix.package }} postgres
234-
if: matrix.package == 'tauri-plugin-sql'
235-
run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --features postgres

plugins/localhost/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ impl Builder {
7474
let asset_resolver = app.asset_resolver();
7575
std::thread::spawn(move || {
7676
let server =
77-
Server::http(&format!("localhost:{port}")).expect("Unable to spawn server");
77+
Server::http(format!("localhost:{port}")).expect("Unable to spawn server");
7878
for req in server.incoming_requests() {
7979
let path = req
8080
.url()

plugins/sql/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ indexmap = { version = "2", features = ["serde"] }
4040
sqlite = ["sqlx/sqlite", "sqlx/runtime-tokio"]
4141
mysql = ["sqlx/mysql", "sqlx/runtime-tokio-rustls"]
4242
postgres = ["sqlx/postgres", "sqlx/runtime-tokio-rustls"]
43+
# TODO: bundled-cipher etc

plugins/sql/src/commands.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
2+
// SPDX-License-Identifier: Apache-2.0
3+
// SPDX-License-Identifier: MIT
4+
5+
use indexmap::IndexMap;
6+
use serde_json::Value as JsonValue;
7+
use sqlx::migrate::Migrator;
8+
use tauri::{command, AppHandle, Runtime, State};
9+
10+
use crate::{DbInstances, DbPool, Error, LastInsertId, Migrations};
11+
12+
#[command]
13+
pub(crate) async fn load<R: Runtime>(
14+
app: AppHandle<R>,
15+
db_instances: State<'_, DbInstances>,
16+
migrations: State<'_, Migrations>,
17+
db: String,
18+
) -> Result<String, crate::Error> {
19+
let pool = DbPool::connect(&db, &app).await?;
20+
21+
if let Some(migrations) = migrations.0.lock().await.remove(&db) {
22+
let migrator = Migrator::new(migrations).await?;
23+
pool.migrate(&migrator).await?;
24+
}
25+
26+
db_instances.0.lock().await.insert(db.clone(), pool);
27+
28+
Ok(db)
29+
}
30+
31+
/// Allows the database connection(s) to be closed; if no database
32+
/// name is passed in then _all_ database connection pools will be
33+
/// shut down.
34+
#[command]
35+
pub(crate) async fn close(
36+
db_instances: State<'_, DbInstances>,
37+
db: Option<String>,
38+
) -> Result<bool, crate::Error> {
39+
let mut instances = db_instances.0.lock().await;
40+
41+
let pools = if let Some(db) = db {
42+
vec![db]
43+
} else {
44+
instances.keys().cloned().collect()
45+
};
46+
47+
for pool in pools {
48+
let db = instances
49+
.get_mut(&pool)
50+
.ok_or(Error::DatabaseNotLoaded(pool))?;
51+
db.close().await;
52+
}
53+
54+
Ok(true)
55+
}
56+
57+
/// Execute a command against the database
58+
#[command]
59+
pub(crate) async fn execute(
60+
db_instances: State<'_, DbInstances>,
61+
db: String,
62+
query: String,
63+
values: Vec<JsonValue>,
64+
) -> Result<(u64, LastInsertId), crate::Error> {
65+
let mut instances = db_instances.0.lock().await;
66+
67+
let db = instances.get_mut(&db).ok_or(Error::DatabaseNotLoaded(db))?;
68+
db.execute(query, values).await
69+
}
70+
71+
#[command]
72+
pub(crate) async fn select(
73+
db_instances: State<'_, DbInstances>,
74+
db: String,
75+
query: String,
76+
values: Vec<JsonValue>,
77+
) -> Result<Vec<IndexMap<String, JsonValue>>, crate::Error> {
78+
let mut instances = db_instances.0.lock().await;
79+
80+
let db = instances.get_mut(&db).ok_or(Error::DatabaseNotLoaded(db))?;
81+
db.select(query, values).await
82+
}

plugins/sql/src/decode/mod.rs

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,8 @@
33
// SPDX-License-Identifier: MIT
44

55
#[cfg(feature = "mysql")]
6-
mod mysql;
6+
pub(crate) mod mysql;
77
#[cfg(feature = "postgres")]
8-
mod postgres;
8+
pub(crate) mod postgres;
99
#[cfg(feature = "sqlite")]
10-
mod sqlite;
11-
12-
#[cfg(feature = "mysql")]
13-
pub(crate) use mysql::to_json;
14-
15-
#[cfg(feature = "postgres")]
16-
pub(crate) use postgres::to_json;
17-
18-
#[cfg(feature = "sqlite")]
19-
pub(crate) use sqlite::to_json;
10+
pub(crate) mod sqlite;

plugins/sql/src/error.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
2+
// SPDX-License-Identifier: Apache-2.0
3+
// SPDX-License-Identifier: MIT
4+
5+
use serde::{Serialize, Serializer};
6+
7+
#[derive(Debug, thiserror::Error)]
8+
pub enum Error {
9+
#[error(transparent)]
10+
Sql(#[from] sqlx::Error),
11+
#[error(transparent)]
12+
Migration(#[from] sqlx::migrate::MigrateError),
13+
#[error("invalid connection url: {0}")]
14+
InvalidDbUrl(String),
15+
#[error("database {0} not loaded")]
16+
DatabaseNotLoaded(String),
17+
#[error("unsupported datatype: {0}")]
18+
UnsupportedDatatype(String),
19+
}
20+
21+
impl Serialize for Error {
22+
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
23+
where
24+
S: Serializer,
25+
{
26+
serializer.serialize_str(self.to_string().as_ref())
27+
}
28+
}

plugins/sql/src/lib.rs

Lines changed: 164 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,168 @@
1111
html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
1212
)]
1313

14-
#[cfg(any(
15-
all(feature = "sqlite", feature = "mysql"),
16-
all(feature = "sqlite", feature = "postgres"),
17-
all(feature = "mysql", feature = "postgres")
18-
))]
19-
compile_error!(
20-
"Only one database driver can be enabled. Set the feature flag for the driver of your choice."
21-
);
22-
23-
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
24-
compile_error!(
25-
"Database driver not defined. Please set the feature flag for the driver of your choice."
26-
);
27-
14+
mod commands;
2815
mod decode;
29-
mod plugin;
30-
pub use plugin::*;
16+
mod error;
17+
mod wrapper;
18+
19+
pub use error::Error;
20+
pub use wrapper::DbPool;
21+
22+
use futures_core::future::BoxFuture;
23+
use serde::{Deserialize, Serialize};
24+
use sqlx::{
25+
error::BoxDynError,
26+
migrate::{Migration as SqlxMigration, MigrationSource, MigrationType, Migrator},
27+
};
28+
use tauri::{
29+
plugin::{Builder as PluginBuilder, TauriPlugin},
30+
Manager, RunEvent, Runtime,
31+
};
32+
use tokio::sync::Mutex;
33+
34+
use std::collections::HashMap;
35+
36+
#[derive(Default)]
37+
pub struct DbInstances(pub Mutex<HashMap<String, DbPool>>);
38+
39+
#[derive(Serialize)]
40+
#[serde(untagged)]
41+
pub(crate) enum LastInsertId {
42+
#[cfg(feature = "sqlite")]
43+
Sqlite(i64),
44+
#[cfg(feature = "mysql")]
45+
MySql(u64),
46+
#[cfg(feature = "postgres")]
47+
Postgres(()),
48+
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
49+
None,
50+
}
51+
52+
struct Migrations(Mutex<HashMap<String, MigrationList>>);
53+
54+
#[derive(Default, Clone, Deserialize)]
55+
pub struct PluginConfig {
56+
#[serde(default)]
57+
preload: Vec<String>,
58+
}
59+
60+
#[derive(Debug)]
61+
pub enum MigrationKind {
62+
Up,
63+
Down,
64+
}
65+
66+
impl From<MigrationKind> for MigrationType {
67+
fn from(kind: MigrationKind) -> Self {
68+
match kind {
69+
MigrationKind::Up => Self::ReversibleUp,
70+
MigrationKind::Down => Self::ReversibleDown,
71+
}
72+
}
73+
}
74+
75+
/// A migration definition.
76+
#[derive(Debug)]
77+
pub struct Migration {
78+
pub version: i64,
79+
pub description: &'static str,
80+
pub sql: &'static str,
81+
pub kind: MigrationKind,
82+
}
83+
84+
#[derive(Debug)]
85+
struct MigrationList(Vec<Migration>);
86+
87+
impl MigrationSource<'static> for MigrationList {
88+
fn resolve(self) -> BoxFuture<'static, std::result::Result<Vec<SqlxMigration>, BoxDynError>> {
89+
Box::pin(async move {
90+
let mut migrations = Vec::new();
91+
for migration in self.0 {
92+
if matches!(migration.kind, MigrationKind::Up) {
93+
migrations.push(SqlxMigration::new(
94+
migration.version,
95+
migration.description.into(),
96+
migration.kind.into(),
97+
migration.sql.into(),
98+
false,
99+
));
100+
}
101+
}
102+
Ok(migrations)
103+
})
104+
}
105+
}
106+
107+
/// Tauri SQL plugin builder.
108+
#[derive(Default)]
109+
pub struct Builder {
110+
migrations: Option<HashMap<String, MigrationList>>,
111+
}
112+
113+
impl Builder {
114+
pub fn new() -> Self {
115+
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
116+
eprintln!("No sql driver enabled. Please set at least one of the \"sqlite\", \"mysql\", \"postgres\" feature flags.");
117+
118+
Self::default()
119+
}
120+
121+
/// Add migrations to a database.
122+
#[must_use]
123+
pub fn add_migrations(mut self, db_url: &str, migrations: Vec<Migration>) -> Self {
124+
self.migrations
125+
.get_or_insert(Default::default())
126+
.insert(db_url.to_string(), MigrationList(migrations));
127+
self
128+
}
129+
130+
pub fn build<R: Runtime>(mut self) -> TauriPlugin<R, Option<PluginConfig>> {
131+
PluginBuilder::<R, Option<PluginConfig>>::new("sql")
132+
.invoke_handler(tauri::generate_handler![
133+
commands::load,
134+
commands::execute,
135+
commands::select,
136+
commands::close
137+
])
138+
.setup(|app, api| {
139+
let config = api.config().clone().unwrap_or_default();
140+
141+
tauri::async_runtime::block_on(async move {
142+
let instances = DbInstances::default();
143+
let mut lock = instances.0.lock().await;
144+
145+
for db in config.preload {
146+
let pool = DbPool::connect(&db, app).await?;
147+
148+
if let Some(migrations) = self.migrations.as_mut().unwrap().remove(&db) {
149+
let migrator = Migrator::new(migrations).await?;
150+
pool.migrate(&migrator).await?;
151+
}
152+
153+
lock.insert(db, pool);
154+
}
155+
drop(lock);
156+
157+
app.manage(instances);
158+
app.manage(Migrations(Mutex::new(
159+
self.migrations.take().unwrap_or_default(),
160+
)));
161+
162+
Ok(())
163+
})
164+
})
165+
.on_event(|app, event| {
166+
if let RunEvent::Exit = event {
167+
tauri::async_runtime::block_on(async move {
168+
let instances = &*app.state::<DbInstances>();
169+
let instances = instances.0.lock().await;
170+
for value in instances.values() {
171+
value.close().await;
172+
}
173+
});
174+
}
175+
})
176+
.build()
177+
}
178+
}

0 commit comments

Comments
 (0)