Skip to content

Commit 958399c

Browse files
authored
feat: RDS custom database name (#1651)
* feat: RDS custom database name Allow RDS instances to have a custom DB name. Or default to the project name. * test: update * refactor: only use db_name for RDS types * refactor: comments * refactor: create DB for local runs * refactor: fix variable name * bug: have local mongo default to admin * refactor: staging test * refactor: undo test changes
1 parent 5742dc8 commit 958399c

File tree

10 files changed

+58
-25
lines changed

10 files changed

+58
-25
lines changed

cargo-shuttle/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,7 @@ impl Shuttle {
12061206
prov.provision_database(Request::new(DatabaseRequest {
12071207
project_name: project_name.to_string(),
12081208
db_type: Some(db_type.into()),
1209+
db_name: config.db_name,
12091210
}))
12101211
.await?
12111212
.into_inner()

cargo-shuttle/src/provisioner_server.rs

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crossterm::{
1818
use futures::StreamExt;
1919
use portpicker::pick_unused_port;
2020
use shuttle_common::{
21-
database::{AwsRdsEngine, SharedEngine},
21+
database::{self, AwsRdsEngine, SharedEngine},
2222
ContainerRequest, ContainerResponse, Secret,
2323
};
2424
use shuttle_proto::provisioner::{
@@ -157,20 +157,26 @@ impl LocalProvisioner {
157157
&self,
158158
project_name: &str,
159159
db_type: Type,
160+
db_name: Option<String>,
160161
) -> Result<DatabaseResponse, Status> {
161162
trace!("getting sql string for project '{project_name}'");
162163

164+
let database_name = match db_type {
165+
database::Type::AwsRds(_) => db_name.unwrap_or_else(|| project_name.to_string()),
166+
database::Type::Shared(SharedEngine::MongoDb) => "admin".to_string(),
167+
_ => project_name.to_string(),
168+
};
169+
163170
let EngineConfig {
164171
r#type,
165172
image,
166173
engine,
167174
username,
168175
password,
169-
database_name,
170176
port,
171177
env,
172178
is_ready_cmd,
173-
} = db_type_to_config(db_type);
179+
} = db_type_to_config(db_type, &database_name);
174180
let container_name = format!("shuttle_{project_name}_{type}");
175181

176182
let container = self
@@ -320,12 +326,13 @@ impl Provisioner for LocalProvisioner {
320326
let DatabaseRequest {
321327
project_name,
322328
db_type,
329+
db_name,
323330
} = request.into_inner();
324331

325332
let db_type: Option<Type> = db_type.unwrap().into();
326333

327334
let res = self
328-
.get_db_connection_string(&project_name, db_type.unwrap())
335+
.get_db_connection_string(&project_name, db_type.unwrap(), db_name)
329336
.await?;
330337

331338
Ok(Response::new(res))
@@ -387,23 +394,24 @@ struct EngineConfig {
387394
engine: String,
388395
username: String,
389396
password: Secret<String>,
390-
database_name: String,
391397
port: String,
392398
env: Option<Vec<String>>,
393399
is_ready_cmd: Vec<String>,
394400
}
395401

396-
fn db_type_to_config(db_type: Type) -> EngineConfig {
402+
fn db_type_to_config(db_type: Type, database_name: &str) -> EngineConfig {
397403
match db_type {
398404
Type::Shared(SharedEngine::Postgres) => EngineConfig {
399405
r#type: "shared_postgres".to_string(),
400406
image: "docker.io/library/postgres:14".to_string(),
401407
engine: "postgres".to_string(),
402408
username: "postgres".to_string(),
403409
password: "postgres".to_string().into(),
404-
database_name: "postgres".to_string(),
405410
port: "5432/tcp".to_string(),
406-
env: Some(vec!["POSTGRES_PASSWORD=postgres".to_string()]),
411+
env: Some(vec![
412+
"POSTGRES_PASSWORD=postgres".to_string(),
413+
format!("POSTGRES_DB={database_name}"),
414+
]),
407415
is_ready_cmd: vec![
408416
"/bin/sh".to_string(),
409417
"-c".to_string(),
@@ -416,11 +424,11 @@ fn db_type_to_config(db_type: Type) -> EngineConfig {
416424
engine: "mongodb".to_string(),
417425
username: "mongodb".to_string(),
418426
password: "password".to_string().into(),
419-
database_name: "admin".to_string(),
420427
port: "27017/tcp".to_string(),
421428
env: Some(vec![
422429
"MONGO_INITDB_ROOT_USERNAME=mongodb".to_string(),
423430
"MONGO_INITDB_ROOT_PASSWORD=password".to_string(),
431+
format!("MONGO_INITDB_DATABASE={database_name}"),
424432
]),
425433
is_ready_cmd: vec![
426434
"mongosh".to_string(),
@@ -435,9 +443,11 @@ fn db_type_to_config(db_type: Type) -> EngineConfig {
435443
engine: "postgres".to_string(),
436444
username: "postgres".to_string(),
437445
password: "postgres".to_string().into(),
438-
database_name: "postgres".to_string(),
439446
port: "5432/tcp".to_string(),
440-
env: Some(vec!["POSTGRES_PASSWORD=postgres".to_string()]),
447+
env: Some(vec![
448+
"POSTGRES_PASSWORD=postgres".to_string(),
449+
format!("POSTGRES_DB={database_name}"),
450+
]),
441451
is_ready_cmd: vec![
442452
"/bin/sh".to_string(),
443453
"-c".to_string(),
@@ -450,9 +460,11 @@ fn db_type_to_config(db_type: Type) -> EngineConfig {
450460
engine: "mariadb".to_string(),
451461
username: "root".to_string(),
452462
password: "mariadb".to_string().into(),
453-
database_name: "mysql".to_string(),
454463
port: "3306/tcp".to_string(),
455-
env: Some(vec!["MARIADB_ROOT_PASSWORD=mariadb".to_string()]),
464+
env: Some(vec![
465+
"MARIADB_ROOT_PASSWORD=mariadb".to_string(),
466+
format!("MARIADB_DATABASE={database_name}"),
467+
]),
456468
is_ready_cmd: vec![
457469
"mysql".to_string(),
458470
"-pmariadb".to_string(),
@@ -467,9 +479,11 @@ fn db_type_to_config(db_type: Type) -> EngineConfig {
467479
engine: "mysql".to_string(),
468480
username: "root".to_string(),
469481
password: "mysql".to_string().into(),
470-
database_name: "mysql".to_string(),
471482
port: "3306/tcp".to_string(),
472-
env: Some(vec!["MYSQL_ROOT_PASSWORD=mysql".to_string()]),
483+
env: Some(vec![
484+
"MYSQL_ROOT_PASSWORD=mysql".to_string(),
485+
format!("MYSQL_DATABASE={database_name}"),
486+
]),
473487
is_ready_cmd: vec![
474488
"mysql".to_string(),
475489
"-pmysql".to_string(),

common/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ impl AsRef<str> for ApiKey {
9292
#[derive(Deserialize, Serialize, Default)]
9393
pub struct DbInput {
9494
pub local_uri: Option<String>,
95+
/// Override the default db name. Only applies to RDS.
96+
pub db_name: Option<String>,
9597
}
9698

9799
/// The output produced by Shuttle DB resources

deployer/src/deployment/run.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ async fn provision(
509509
match shuttle_resource.r#type {
510510
resource::Type::Database(db_type) => {
511511
// no config fields are used yet, but verify the format anyways
512-
let _config: DbInput = serde_json::from_value(shuttle_resource.config.clone())
512+
let config: DbInput = serde_json::from_value(shuttle_resource.config.clone())
513513
.context("deserializing resource config")?;
514514

515515
let output = get_cached_output(&shuttle_resource, prev_resources.as_slice());
@@ -521,6 +521,7 @@ async fn provision(
521521
let mut req = Request::new(DatabaseRequest {
522522
project_name: project_name.to_string(),
523523
db_type: Some(db_type.into()),
524+
db_name: config.db_name,
524525
// other relevant config fields would go here
525526
});
526527
req.extensions_mut().insert(claim.clone());

deployer/src/persistence/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ impl ResourceManager for Persistence {
480480
let mut req = Request::new(DatabaseRequest {
481481
project_name,
482482
db_type: Some(db_type.into()),
483+
db_name: None,
483484
});
484485
req.extensions_mut().insert(claim.clone());
485486

proto/provisioner.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ message DatabaseRequest {
1313
Shared Shared = 10;
1414
AwsRds AwsRds = 11;
1515
};
16+
// Override the default db name. Only applies to RDS.
17+
optional string db_name = 2;
1618
}
1719

1820
message Shared {

proto/src/generated/provisioner.rs

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

provisioner/src/lib.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ impl ShuttleProvisioner {
265265
&self,
266266
project_name: &str,
267267
engine: aws_rds::Engine,
268+
database_name: &Option<String>,
268269
) -> Result<DatabaseResponse, Error> {
269270
let client = &self.rds_client;
270271

@@ -287,13 +288,9 @@ impl ShuttleProvisioner {
287288
if let ModifyDBInstanceError::DbInstanceNotFoundFault(_) = err.err() {
288289
debug!("creating new AWS RDS {instance_name}");
289290

290-
// The engine display impl is used for both the engine and the database name,
291-
// but for mysql the engine name is an invalid database name.
292-
let db_name = if let aws_rds::Engine::Mysql(_) = engine {
293-
"msql".to_string()
294-
} else {
295-
engine.to_string()
296-
};
291+
let db_name = database_name
292+
.to_owned()
293+
.unwrap_or_else(|| project_name.to_string());
297294

298295
client
299296
.create_db_instance()
@@ -527,8 +524,12 @@ impl Provisioner for ShuttleProvisioner {
527524
}
528525
}
529526

530-
self.request_aws_rds(&request.project_name, engine.expect("engine to be set"))
531-
.await?
527+
self.request_aws_rds(
528+
&request.project_name,
529+
engine.expect("engine to be set"),
530+
&request.db_name,
531+
)
532+
.await?
532533
}
533534
};
534535

provisioner/tests/provisioner.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ mod needs_docker {
9393
db_type: Some(DbType::AwsRds(AwsRds {
9494
engine: Some(Engine::Postgres(Default::default())),
9595
})),
96+
db_name: Some("custom-name".to_string()),
9697
});
9798

9899
// Add a claim that only allows for one RDS - the one that will be returned by r-r

resources/aws-rds/src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ macro_rules! aws_engine {
2424

2525
self
2626
}
27+
28+
/// Use something other than the project name as the DB name
29+
pub fn database_name(mut self, database_name: &str) -> Self {
30+
self.0.db_name = Some(database_name.to_string());
31+
32+
self
33+
}
2734
}
2835

2936
#[cfg(feature = $feature)]

0 commit comments

Comments
 (0)