|
| 1 | +/* |
| 2 | + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | + |
| 6 | +//! Code related to creating signed URLs for logging in to RDS. |
| 7 | +//! |
| 8 | +//! For more information, see <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.Connecting.html> |
| 9 | +
|
| 10 | +use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider}; |
| 11 | +use aws_sigv4::http_request; |
| 12 | +use aws_sigv4::http_request::{SignableBody, SignableRequest, SigningSettings}; |
| 13 | +use aws_sigv4::sign::v4; |
| 14 | +use aws_smithy_runtime_api::box_error::BoxError; |
| 15 | +use aws_smithy_runtime_api::client::identity::Identity; |
| 16 | +use aws_types::region::Region; |
| 17 | +use std::fmt; |
| 18 | +use std::fmt::Debug; |
| 19 | +use std::time::Duration; |
| 20 | + |
| 21 | +const ACTION: &str = "connect"; |
| 22 | +const SERVICE: &str = "rds-db"; |
| 23 | + |
| 24 | +/// A signer that generates an auth token for a database. |
| 25 | +/// |
| 26 | +/// ## Example |
| 27 | +/// |
| 28 | +/// ```ignore |
| 29 | +/// use crate::auth_token::{AuthTokenGenerator, Config}; |
| 30 | +/// |
| 31 | +/// #[tokio::main] |
| 32 | +/// async fn main() { |
| 33 | +/// let cfg = aws_config::load_defaults(BehaviorVersion::latest()).await; |
| 34 | +/// let generator = AuthTokenGenerator::new( |
| 35 | +/// Config::builder() |
| 36 | +/// .hostname("zhessler-test-db.cp7a4mblr2ig.us-east-1.rds.amazonaws.com") |
| 37 | +/// .port(5432) |
| 38 | +/// .username("zhessler") |
| 39 | +/// .build() |
| 40 | +/// .expect("cfg is valid"), |
| 41 | +/// ); |
| 42 | +/// let token = generator.auth_token(&cfg).await.unwrap(); |
| 43 | +/// println!("{token}"); |
| 44 | +/// } |
| 45 | +/// ``` |
| 46 | +#[derive(Debug)] |
| 47 | +pub struct AuthTokenGenerator { |
| 48 | + config: Config, |
| 49 | +} |
| 50 | + |
| 51 | +/// An auth token usable as a password for an RDS database. |
| 52 | +/// |
| 53 | +/// This struct can be converted into a `&str` using the `Deref` trait or by calling `to_string()`. |
| 54 | +#[derive(Clone, Debug, PartialEq, Eq)] |
| 55 | +pub struct AuthToken { |
| 56 | + inner: String, |
| 57 | +} |
| 58 | + |
| 59 | +impl AuthToken { |
| 60 | + /// Return the auth token as a `&str`. |
| 61 | + #[must_use] |
| 62 | + pub fn as_str(&self) -> &str { |
| 63 | + &self.inner |
| 64 | + } |
| 65 | +} |
| 66 | + |
| 67 | +impl fmt::Display for AuthToken { |
| 68 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 69 | + write!(f, "{}", self.inner) |
| 70 | + } |
| 71 | +} |
| 72 | + |
| 73 | +impl AuthTokenGenerator { |
| 74 | + /// Given a `Config`, create a new RDS database login URL signer. |
| 75 | + pub fn new(config: Config) -> Self { |
| 76 | + Self { config } |
| 77 | + } |
| 78 | + |
| 79 | + /// Return a signed URL usable as an auth token. |
| 80 | + pub async fn auth_token( |
| 81 | + &self, |
| 82 | + config: &aws_types::sdk_config::SdkConfig, |
| 83 | + ) -> Result<AuthToken, BoxError> { |
| 84 | + let credentials = self |
| 85 | + .config |
| 86 | + .credentials() |
| 87 | + .or(config.credentials_provider()) |
| 88 | + .ok_or("credentials are required to create a signed URL for RDS")? |
| 89 | + .provide_credentials() |
| 90 | + .await?; |
| 91 | + let identity: Identity = credentials.into(); |
| 92 | + let region = self |
| 93 | + .config |
| 94 | + .region() |
| 95 | + .or(config.region()) |
| 96 | + .cloned() |
| 97 | + .unwrap_or_else(|| Region::new("us-east-1")); |
| 98 | + let time = config.time_source().ok_or("a time source is required")?; |
| 99 | + |
| 100 | + let mut signing_settings = SigningSettings::default(); |
| 101 | + signing_settings.expires_in = Some(Duration::from_secs( |
| 102 | + self.config.expires_in().unwrap_or(900).min(900), |
| 103 | + )); |
| 104 | + signing_settings.signature_location = http_request::SignatureLocation::QueryParams; |
| 105 | + |
| 106 | + let signing_params = v4::SigningParams::builder() |
| 107 | + .identity(&identity) |
| 108 | + .region(region.as_ref()) |
| 109 | + .name(SERVICE) |
| 110 | + .time(time.now()) |
| 111 | + .settings(signing_settings) |
| 112 | + .build()?; |
| 113 | + |
| 114 | + let url = format!( |
| 115 | + "https://{}:{}/?Action={}&DBUser={}", |
| 116 | + self.config.hostname(), |
| 117 | + self.config.port(), |
| 118 | + ACTION, |
| 119 | + self.config.username() |
| 120 | + ); |
| 121 | + let signable_request = |
| 122 | + SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::empty()) |
| 123 | + .expect("signable request"); |
| 124 | + |
| 125 | + let (signing_instructions, _signature) = |
| 126 | + http_request::sign(signable_request, &signing_params.into())?.into_parts(); |
| 127 | + |
| 128 | + let mut url = url::Url::parse(&url).unwrap(); |
| 129 | + for (name, value) in signing_instructions.params() { |
| 130 | + url.query_pairs_mut().append_pair(name, value); |
| 131 | + } |
| 132 | + let inner = url.to_string().split_off("https://".len()); |
| 133 | + |
| 134 | + Ok(AuthToken { inner }) |
| 135 | + } |
| 136 | +} |
| 137 | + |
| 138 | +/// Configuration for an RDS auth URL signer. |
| 139 | +#[derive(Debug, Clone)] |
| 140 | +pub struct Config { |
| 141 | + /// The AWS credentials to sign requests with. |
| 142 | + /// |
| 143 | + /// Uses the default credential provider chain if not specified. |
| 144 | + credentials: Option<SharedCredentialsProvider>, |
| 145 | + |
| 146 | + /// The hostname of the database to connect to. |
| 147 | + hostname: String, |
| 148 | + |
| 149 | + /// The port number the database is listening on. |
| 150 | + port: u64, |
| 151 | + |
| 152 | + /// The region the database is located in. Uses the region inferred from the runtime if omitted. |
| 153 | + region: Option<Region>, |
| 154 | + |
| 155 | + /// The username to login as. |
| 156 | + username: String, |
| 157 | + |
| 158 | + /// The number of seconds the signed URL should be valid for. |
| 159 | + /// |
| 160 | + /// Maxes at 900 seconds. |
| 161 | + expires_in: Option<u64>, |
| 162 | +} |
| 163 | + |
| 164 | +impl Config { |
| 165 | + /// Create a new `SignerConfigBuilder`. |
| 166 | + pub fn builder() -> ConfigBuilder { |
| 167 | + ConfigBuilder::default() |
| 168 | + } |
| 169 | + |
| 170 | + /// The AWS credentials to sign requests with. |
| 171 | + pub fn credentials(&self) -> Option<SharedCredentialsProvider> { |
| 172 | + self.credentials.clone() |
| 173 | + } |
| 174 | + |
| 175 | + /// The hostname of the database to connect to. |
| 176 | + pub fn hostname(&self) -> &str { |
| 177 | + &self.hostname |
| 178 | + } |
| 179 | + |
| 180 | + /// The port number the database is listening on. |
| 181 | + pub fn port(&self) -> u64 { |
| 182 | + self.port |
| 183 | + } |
| 184 | + |
| 185 | + /// The region to sign requests with. |
| 186 | + pub fn region(&self) -> Option<&Region> { |
| 187 | + self.region.as_ref() |
| 188 | + } |
| 189 | + |
| 190 | + /// The DB username to login as. |
| 191 | + pub fn username(&self) -> &str { |
| 192 | + &self.username |
| 193 | + } |
| 194 | + |
| 195 | + /// The number of seconds the signed URL should be valid for. |
| 196 | + /// |
| 197 | + /// Maxes out at 900 seconds. |
| 198 | + pub fn expires_in(&self) -> Option<u64> { |
| 199 | + self.expires_in |
| 200 | + } |
| 201 | +} |
| 202 | + |
| 203 | +/// A builder for [`Config`]s. |
| 204 | +#[derive(Debug, Default)] |
| 205 | +pub struct ConfigBuilder { |
| 206 | + /// The AWS credentials to create the auth token with. |
| 207 | + /// |
| 208 | + /// Uses the default credential provider chain if not specified. |
| 209 | + credentials: Option<SharedCredentialsProvider>, |
| 210 | + |
| 211 | + /// The hostname of the database to connect to. |
| 212 | + hostname: Option<String>, |
| 213 | + |
| 214 | + /// The port number the database is listening on. |
| 215 | + port: Option<u64>, |
| 216 | + |
| 217 | + /// The region the database is located in. Uses the region inferred from the runtime if omitted. |
| 218 | + region: Option<Region>, |
| 219 | + |
| 220 | + /// The database username to login as. |
| 221 | + username: Option<String>, |
| 222 | + |
| 223 | + /// The number of seconds the auth token should be valid for. |
| 224 | + expires_in: Option<u64>, |
| 225 | +} |
| 226 | + |
| 227 | +impl ConfigBuilder { |
| 228 | + /// The AWS credentials to create the auth token with. |
| 229 | + /// |
| 230 | + /// Uses the default credential provider chain if not specified. |
| 231 | + pub fn credentials(mut self, credentials: impl ProvideCredentials + 'static) -> Self { |
| 232 | + self.credentials = Some(SharedCredentialsProvider::new(credentials)); |
| 233 | + self |
| 234 | + } |
| 235 | + |
| 236 | + /// The hostname of the database to connect to. |
| 237 | + pub fn hostname(mut self, hostname: impl Into<String>) -> Self { |
| 238 | + self.hostname = Some(hostname.into()); |
| 239 | + self |
| 240 | + } |
| 241 | + |
| 242 | + /// The port number the database is listening on. |
| 243 | + pub fn port(mut self, port: u64) -> Self { |
| 244 | + self.port = Some(port); |
| 245 | + self |
| 246 | + } |
| 247 | + |
| 248 | + /// The region the database is located in. Uses the region inferred from the runtime if omitted. |
| 249 | + pub fn region(mut self, region: Region) -> Self { |
| 250 | + self.region = Some(region); |
| 251 | + self |
| 252 | + } |
| 253 | + |
| 254 | + /// The database username to login as. |
| 255 | + pub fn username(mut self, username: impl Into<String>) -> Self { |
| 256 | + self.username = Some(username.into()); |
| 257 | + self |
| 258 | + } |
| 259 | + |
| 260 | + /// The number of seconds the signed URL should be valid for. |
| 261 | + /// |
| 262 | + /// Maxes out at 900 seconds. |
| 263 | + pub fn expires_in(mut self, expires_in: u64) -> Self { |
| 264 | + self.expires_in = Some(expires_in); |
| 265 | + self |
| 266 | + } |
| 267 | + |
| 268 | + /// Consume this builder, returning an error if required fields are missing. |
| 269 | + /// Otherwise, return a new `SignerConfig`. |
| 270 | + pub fn build(self) -> Result<Config, BoxError> { |
| 271 | + Ok(Config { |
| 272 | + credentials: self.credentials, |
| 273 | + hostname: self.hostname.ok_or("A hostname is required")?, |
| 274 | + port: self.port.ok_or("a port is required")?, |
| 275 | + region: self.region, |
| 276 | + username: self.username.ok_or("a username is required")?, |
| 277 | + expires_in: self.expires_in, |
| 278 | + }) |
| 279 | + } |
| 280 | +} |
| 281 | + |
| 282 | +#[cfg(test)] |
| 283 | +mod test { |
| 284 | + use super::{AuthTokenGenerator, Config}; |
| 285 | + use aws_credential_types::provider::SharedCredentialsProvider; |
| 286 | + use aws_credential_types::Credentials; |
| 287 | + use aws_smithy_async::test_util::ManualTimeSource; |
| 288 | + use aws_types::region::Region; |
| 289 | + use aws_types::SdkConfig; |
| 290 | + use std::time::{Duration, UNIX_EPOCH}; |
| 291 | + |
| 292 | + #[tokio::test] |
| 293 | + async fn signing_works() { |
| 294 | + let time_source = ManualTimeSource::new(UNIX_EPOCH + Duration::from_secs(1724709600)); |
| 295 | + let sdk_config = SdkConfig::builder() |
| 296 | + .credentials_provider(SharedCredentialsProvider::new(Credentials::new( |
| 297 | + "akid", "secret", None, None, "test", |
| 298 | + ))) |
| 299 | + .time_source(time_source) |
| 300 | + .build(); |
| 301 | + let signer = AuthTokenGenerator::new( |
| 302 | + Config::builder() |
| 303 | + .hostname("prod-instance.us-east-1.rds.amazonaws.com") |
| 304 | + .port(3306) |
| 305 | + .region(Region::new("us-east-1")) |
| 306 | + .username("peccy") |
| 307 | + .build() |
| 308 | + .unwrap(), |
| 309 | + ); |
| 310 | + |
| 311 | + let signed_url = signer.auth_token(&sdk_config).await.unwrap(); |
| 312 | + assert_eq!(signed_url.as_str(), "prod-instance.us-east-1.rds.amazonaws.com:3306/?Action=connect&DBUser=peccy&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=akid%2F20240826%2Fus-east-1%2Frds-db%2Faws4_request&X-Amz-Date=20240826T220000Z&X-Amz-Expires=900&X-Amz-SignedHeaders=host&X-Amz-Signature=dd0cba843009474347af724090233265628ace491ea17ce3eb3da098b983ad89"); |
| 313 | + } |
| 314 | +} |
0 commit comments