Skip to content

Commit 8e5a859

Browse files
committed
Support encryption options
1 parent 23701a2 commit 8e5a859

File tree

5 files changed

+101
-32
lines changed

5 files changed

+101
-32
lines changed

sqlx-core/src/mssql/connection/establish.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,14 @@ impl MssqlConnection {
1818
// TODO: Encryption
1919
// TODO: Send the version of SQLx over
2020

21-
let encryption = match options.encrypt {
22-
Some(true) => Encrypt::Required,
23-
Some(false) => Encrypt::NotSupported,
24-
None => Encrypt::On,
25-
};
26-
27-
log::debug!("Sending T-SQL PRELOGIN with encryption: {:?}", encryption);
21+
log::debug!("Sending T-SQL PRELOGIN with encryption: {:?}", options.encrypt);
2822

2923
stream
3024
.write_packet_and_flush(
3125
PacketType::PreLogin,
3226
PreLogin {
3327
version: Version::default(),
34-
encryption,
28+
encryption: options.encrypt,
3529
instance: options.instance.clone(),
3630

3731
..Default::default()
@@ -47,7 +41,7 @@ impl MssqlConnection {
4741
Encrypt::Required | Encrypt::On
4842
) {
4943
stream.setup_encryption().await?;
50-
} else if encryption == Encrypt::Required {
44+
} else if options.encrypt == Encrypt::Required {
5145
return Err(Error::Tls(Box::new(std::io::Error::new(
5246
std::io::ErrorKind::Other,
5347
"TLS encryption required but not supported by server",

sqlx-core/src/mssql/connection/stream.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,10 @@ impl MssqlStream {
244244

245245
pub(crate) async fn setup_encryption(&mut self) -> Result<(), Error> {
246246
let tls_config = TlsConfig {
247-
accept_invalid_certs: true,
248-
hostname: &self.options.host,
249-
accept_invalid_hostnames: true,
250-
root_cert_path: None,
247+
accept_invalid_certs: self.options.trust_server_certificate,
248+
hostname: self.options.hostname_in_certificate.as_deref().unwrap_or(&self.options.host),
249+
accept_invalid_hostnames: self.options.hostname_in_certificate.is_none(),
250+
root_cert_path: self.options.ssl_root_cert.as_ref(),
251251
client_cert_path: None,
252252
client_key_path: None,
253253
};

sqlx-core/src/mssql/options/mod.rs

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use crate::connection::LogSettings;
1+
use std::path::Path;
2+
3+
use crate::{connection::LogSettings, net::CertificateInput};
4+
use super::protocol::pre_login::Encrypt;
25

36
mod connect;
47
mod parse;
@@ -27,7 +30,10 @@ pub struct MssqlConnectOptions {
2730
pub(crate) language: String,
2831
/// Size in bytes of TDS packets to exchange with the server
2932
pub(crate) requested_packet_size: u32,
30-
pub(crate) encrypt: Option<bool>, // Added field
33+
pub(crate) encrypt: Encrypt,
34+
pub(crate) trust_server_certificate: bool,
35+
pub(crate) hostname_in_certificate: Option<String>,
36+
pub(crate) ssl_root_cert: Option<CertificateInput>,
3137
}
3238

3339
impl Default for MssqlConnectOptions {
@@ -49,12 +55,15 @@ impl MssqlConnectOptions {
4955
requested_packet_size: 4096,
5056
client_program_version: 0,
5157
client_pid: 0,
52-
hostname: "".to_string(),
53-
app_name: "".to_string(),
54-
server_name: "".to_string(),
55-
client_interface_name: "".to_string(),
56-
language: "".to_string(),
57-
encrypt: None,
58+
hostname: String::new(),
59+
app_name: String::new(),
60+
server_name: String::new(),
61+
client_interface_name: String::new(),
62+
language: String::new(),
63+
encrypt: Encrypt::On,
64+
trust_server_certificate: true,
65+
hostname_in_certificate: None,
66+
ssl_root_cert: None,
5867
}
5968
}
6069

@@ -89,12 +98,12 @@ impl MssqlConnectOptions {
8998
}
9099

91100
pub fn client_program_version(mut self, client_program_version: u32) -> Self {
92-
self.client_program_version = client_program_version.to_owned();
101+
self.client_program_version = client_program_version;
93102
self
94103
}
95104

96105
pub fn client_pid(mut self, client_pid: u32) -> Self {
97-
self.client_pid = client_pid.to_owned();
106+
self.client_pid = client_pid;
98107
self
99108
}
100109

@@ -134,8 +143,26 @@ impl MssqlConnectOptions {
134143
}
135144
}
136145

137-
pub fn encrypt(mut self, encrypt: bool) -> Self {
138-
self.encrypt = Some(encrypt);
146+
pub fn encrypt(mut self, encrypt: Encrypt) -> Self {
147+
self.encrypt = encrypt;
148+
self
149+
}
150+
151+
pub fn trust_server_certificate(mut self, trust: bool) -> Self {
152+
self.trust_server_certificate = trust;
153+
self
154+
}
155+
156+
pub fn hostname_in_certificate(mut self, hostname: &str) -> Self {
157+
self.hostname_in_certificate = Some(hostname.to_owned());
158+
self
159+
}
160+
161+
/// Sets the name of a file containing SSL certificate authority (CA) certificate(s).
162+
/// If the file exists, the server's certificate will be verified to be signed by
163+
/// one of these authorities.
164+
pub fn ssl_root_cert(mut self, cert: impl AsRef<Path>) -> Self {
165+
self.ssl_root_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf()));
139166
self
140167
}
141168
}

sqlx-core/src/mssql/options/parse.rs

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::error::Error;
2+
use crate::mssql::protocol::pre_login::Encrypt;
23
use crate::mssql::MssqlConnectOptions;
34
use percent_encoding::percent_decode_str;
45
use std::str::FromStr;
@@ -9,9 +10,39 @@ impl FromStr for MssqlConnectOptions {
910

1011
/// Parse a connection string into a set of connection options.
1112
///
12-
/// The connection string is expected to be a valid URL with the following format:
13+
/// The connection string should be a valid URL with the following format:
1314
/// ```text
14-
/// mssql://[username[:password]@]host/database[?instance=instance_name&packet_size=packet_size&client_program_version=client_program_version&client_pid=client_pid&hostname=hostname&app_name=app_name&server_name=server_name&client_interface_name=client_interface_name&language=language&encrypt=true|false]
15+
/// mssql://[username[:password]@]host[:port][/database][?param1=value1&param2=value2...]
16+
/// ```
17+
///
18+
/// Components:
19+
/// - `username`: The username for SQL Server authentication.
20+
/// - `password`: The password for SQL Server authentication.
21+
/// - `host`: The hostname or IP address of the SQL Server.
22+
/// - `port`: The port number (default is 1433).
23+
/// - `database`: The name of the database to connect to.
24+
///
25+
/// Supported query parameters:
26+
/// - `instance`: SQL Server named instance.
27+
/// - `encrypt`: Controls connection encryption:
28+
/// - `strict`: Requires encryption and validates the server certificate.
29+
/// - `mandatory` or `true` or `yes`: Requires encryption but doesn't validate the server certificate.
30+
/// - `optional` or `false` or `no`: Uses encryption if available, falls back to unencrypted.
31+
/// - `sslrootcert` or `ssl-root-cert` or `ssl-ca`: Path to the root certificate for validating the server's SSL certificate.
32+
/// - `trust_server_certificate`: When true, skips validation of the server's SSL certificate. Use with caution as it makes the connection vulnerable to man-in-the-middle attacks.
33+
/// - `hostname_in_certificate`: The hostname expected in the server's SSL certificate. Use this when the server's hostname doesn't match the certificate.
34+
/// - `packet_size`: Size of TDS packets in bytes. Larger sizes can improve performance but consume more memory on the server
35+
/// - `client_program_version`: Version number of the client program, sent to the server for logging purposes.
36+
/// - `client_pid`: Process ID of the client, sent to the server for logging purposes.
37+
/// - `hostname`: Name of the client machine, sent to the server for logging purposes.
38+
/// - `app_name`: Name of the client application, sent to the server for logging purposes.
39+
/// - `server_name`: Name of the server to connect to. Useful when connecting through a proxy or load balancer.
40+
/// - `client_interface_name`: Name of the client interface, sent to the server for logging purposes.
41+
/// - `language`: Sets the language for server messages. Affects date formats and system messages.
42+
///
43+
/// Example:
44+
/// ```text
45+
/// mssql://user:pass@localhost:1433/mydb?encrypt=strict&app_name=MyApp&packet_size=4096
1546
/// ```
1647
fn from_str(s: &str) -> Result<Self, Self::Err> {
1748
let url: Url = s.parse().map_err(Error::config)?;
@@ -52,6 +83,27 @@ impl FromStr for MssqlConnectOptions {
5283
"instance" => {
5384
options = options.instance(&*value);
5485
}
86+
"encrypt" => {
87+
match value.to_lowercase().as_str() {
88+
"strict" => options = options.encrypt(Encrypt::Required),
89+
"mandatory" | "true" | "yes" => options = options.encrypt(Encrypt::On),
90+
"optional" | "false" | "no" => options = options.encrypt(Encrypt::NotSupported),
91+
_ => return Err(Error::config(MssqlInvalidOption(format!(
92+
"encrypt={} is not a valid value for encrypt. Valid values are: strict, mandatory, optional, true, false, yes, no",
93+
value
94+
)))),
95+
}
96+
}
97+
"sslrootcert" | "ssl-root-cert" | "ssl-ca" => {
98+
options = options.ssl_root_cert(&*value);
99+
}
100+
"trust_server_certificate" => {
101+
let trust = value.parse::<bool>().map_err(Error::config)?;
102+
options = options.trust_server_certificate(trust);
103+
}
104+
"hostname_in_certificate" => {
105+
options = options.hostname_in_certificate(&*value);
106+
}
55107
"packet_size" => {
56108
let size = value.parse().map_err(Error::config)?;
57109
options = options.requested_packet_size(size).map_err(|_| {
@@ -67,10 +119,6 @@ impl FromStr for MssqlConnectOptions {
67119
"server_name" => options = options.server_name(&*value),
68120
"client_interface_name" => options = options.client_interface_name(&*value),
69121
"language" => options = options.language(&*value),
70-
"encrypt" => {
71-
let encrypt = value.parse::<bool>().map_err(Error::config)?;
72-
options = options.encrypt(encrypt);
73-
}
74122
_ => {
75123
return Err(Error::config(MssqlInvalidOption(key.into())));
76124
}

sqlx-core/src/mssql/protocol/pre_login.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ impl Display for Version {
277277
/// During the Pre-Login handshake, the client and the server negotiate the
278278
/// wire encryption to be used.
279279
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Hash)]
280-
pub(crate) enum Encrypt {
280+
pub enum Encrypt {
281281
/// Encryption is available but off.
282282
Off = 0x00,
283283

0 commit comments

Comments
 (0)