Skip to content

Commit a18cd3c

Browse files
committed
Fix mssql TDS packet splitting over TLS connections
1 parent 916301f commit a18cd3c

File tree

10 files changed

+140
-91
lines changed

10 files changed

+140
-91
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
- Add sqlx version information to pre-login message in mssql
1111
- Add support for encrypted Microsoft SQL server connections (using TLS)
12+
- Add support for the `SSLKEYLOGFILE` environment variable for TLS decryption in Wireshark
1213

1314
## 0.6.27
1415

sqlx-core/src/mssql/connection/establish.rs

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,18 @@ impl MssqlConnection {
2626

2727
log::debug!("Sending T-SQL PRELOGIN with encryption: {:?}", encryption);
2828

29-
stream.write_packet(
30-
PacketType::PreLogin,
31-
PreLogin {
32-
version: Version::default(),
33-
encryption,
34-
instance: options.instance.clone(),
35-
36-
..Default::default()
37-
},
38-
);
39-
40-
stream.flush().await?;
29+
stream
30+
.write_packet_and_flush(
31+
PacketType::PreLogin,
32+
PreLogin {
33+
version: Version::default(),
34+
encryption,
35+
instance: options.instance.clone(),
36+
37+
..Default::default()
38+
},
39+
)
40+
.await?;
4141

4242
let (_, packet) = stream.recv_packet().await?;
4343
let prelogin_response = PreLogin::decode(packet)?;
@@ -47,31 +47,36 @@ impl MssqlConnection {
4747
Encrypt::Required | Encrypt::On
4848
) {
4949
stream.setup_encryption().await?;
50+
} else if encryption == Encrypt::Required {
51+
return Err(Error::Tls(Box::new(std::io::Error::new(
52+
std::io::ErrorKind::Other,
53+
"TLS encryption required but not supported by server",
54+
))));
5055
}
5156

5257
// LOGIN7 defines the authentication rules for use between client and server
5358

54-
stream.write_packet(
55-
PacketType::Tds7Login,
56-
Login7 {
57-
// FIXME: use a version constant
58-
version: 0x74000004, // SQL Server 2012 - SQL Server 2019
59-
client_program_version: options.client_program_version,
60-
client_pid: options.client_pid,
61-
packet_size: options.requested_packet_size, // max allowed size of TDS packet
62-
hostname: &options.hostname,
63-
username: &options.username,
64-
password: options.password.as_deref().unwrap_or_default(),
65-
app_name: &options.app_name,
66-
server_name: &options.server_name,
67-
client_interface_name: &options.client_interface_name,
68-
language: &options.language,
69-
database: &*options.database,
70-
client_id: [0; 6],
71-
},
72-
);
73-
74-
stream.flush().await?;
59+
stream
60+
.write_packet_and_flush(
61+
PacketType::Tds7Login,
62+
Login7 {
63+
// FIXME: use a version constant
64+
version: 0x74000004, // SQL Server 2012 - SQL Server 2019
65+
client_program_version: options.client_program_version,
66+
client_pid: options.client_pid,
67+
packet_size: options.requested_packet_size, // max allowed size of TDS packet
68+
hostname: &options.hostname,
69+
username: &options.username,
70+
password: options.password.as_deref().unwrap_or_default(),
71+
app_name: &options.app_name,
72+
server_name: &options.server_name,
73+
client_interface_name: &options.client_interface_name,
74+
language: &options.language,
75+
database: &*options.database,
76+
client_id: [0; 6],
77+
},
78+
)
79+
.await?;
7580

7681
loop {
7782
// NOTE: we should receive an [Error] message if something goes wrong, otherwise,

sqlx-core/src/mssql/connection/executor.rs

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,29 @@ impl MssqlConnection {
4141
proc_args.append(&mut arguments);
4242
}
4343

44-
self.stream.write_packet(
45-
PacketType::Rpc,
46-
RpcRequest {
47-
transaction_descriptor: self.stream.transaction_descriptor,
48-
arguments: &proc_args,
49-
procedure: proc,
50-
options: OptionFlags::empty(),
51-
},
52-
);
44+
self.stream
45+
.write_packet_and_flush(
46+
PacketType::Rpc,
47+
RpcRequest {
48+
transaction_descriptor: self.stream.transaction_descriptor,
49+
arguments: &proc_args,
50+
procedure: proc,
51+
options: OptionFlags::empty(),
52+
},
53+
)
54+
.await?;
5355
} else {
54-
self.stream.write_packet(
55-
PacketType::SqlBatch,
56-
SqlBatch {
57-
transaction_descriptor: self.stream.transaction_descriptor,
58-
sql: query,
59-
},
60-
);
56+
self.stream
57+
.write_packet_and_flush(
58+
PacketType::SqlBatch,
59+
SqlBatch {
60+
transaction_descriptor: self.stream.transaction_descriptor,
61+
sql: query,
62+
},
63+
)
64+
.await?;
6165
}
6266

63-
self.stream.flush().await?;
64-
6567
Ok(())
6668
}
6769
}

sqlx-core/src/mssql/connection/prepare.rs

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,22 @@ pub(crate) async fn prepare(
5252
args.add_unnamed(sql);
5353
args.add_unnamed(0x0001_i32); // 1 = SEND_METADATA
5454

55-
conn.stream.write_packet(
56-
PacketType::Rpc,
57-
RpcRequest {
58-
transaction_descriptor: conn.stream.transaction_descriptor,
59-
arguments: &args,
60-
// [sp_prepare] will emit the column meta data
61-
// small issue is that we need to declare all the used placeholders with a "fallback" type
62-
// we currently use regex to collect them; false positives are *okay* but false
63-
// negatives would break the query
64-
procedure: Either::Right(Procedure::Prepare),
65-
options: OptionFlags::empty(),
66-
},
67-
);
68-
69-
conn.stream.flush().await?;
55+
conn.stream
56+
.write_packet_and_flush(
57+
PacketType::Rpc,
58+
RpcRequest {
59+
transaction_descriptor: conn.stream.transaction_descriptor,
60+
arguments: &args,
61+
// [sp_prepare] will emit the column meta data
62+
// small issue is that we need to declare all the used placeholders with a "fallback" type
63+
// we currently use regex to collect them; false positives are *okay* but false
64+
// negatives would break the query
65+
procedure: Either::Right(Procedure::Prepare),
66+
options: OptionFlags::empty(),
67+
},
68+
)
69+
.await?;
70+
7071
conn.stream.wait_until_ready().await?;
7172
conn.stream.pending_done_count += 1;
7273

@@ -100,17 +101,18 @@ pub(crate) async fn prepare(
100101
let mut args = MssqlArguments::default();
101102
args.add_unnamed(id);
102103

103-
conn.stream.write_packet(
104-
PacketType::Rpc,
105-
RpcRequest {
106-
transaction_descriptor: conn.stream.transaction_descriptor,
107-
arguments: &args,
108-
procedure: Either::Right(Procedure::Unprepare),
109-
options: OptionFlags::empty(),
110-
},
111-
);
104+
conn.stream
105+
.write_packet_and_flush(
106+
PacketType::Rpc,
107+
RpcRequest {
108+
transaction_descriptor: conn.stream.transaction_descriptor,
109+
arguments: &args,
110+
procedure: Either::Right(Procedure::Unprepare),
111+
options: OptionFlags::empty(),
112+
},
113+
)
114+
.await?;
112115

113-
conn.stream.flush().await?;
114116
conn.stream.wait_until_ready().await?;
115117
conn.stream.pending_done_count += 1;
116118

sqlx-core/src/mssql/connection/stream.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,39 @@ impl MssqlStream {
7171
})
7272
}
7373

74-
// writes the packet out to the write buffer
74+
// writes the packet out to the write buffer, but does not flush
75+
// WARNING: if the packet is large, and we are over an encrypted connection, this will fail, since we would
76+
// need to flush each packet individually to the encryption layer
7577
pub(crate) fn write_packet<'en, T: Encode<'en>>(&mut self, ty: PacketType, payload: T) {
7678
write_packets(&mut self.inner.wbuf, self.max_packet_size, ty, payload)
7779
}
7880

81+
// writes the packet out to the write buffer, splitting it if neccessary, and flushing TDS packets one at a time
82+
pub(crate) async fn write_packet_and_flush<'en, T: Encode<'en>>(
83+
&mut self,
84+
ty: PacketType,
85+
payload: T,
86+
) -> Result<(), Error> {
87+
if !self.inner.wbuf.is_empty() {
88+
self.flush().await?;
89+
}
90+
self.write_packet(ty, payload);
91+
self.flush().await?;
92+
Ok(())
93+
}
94+
95+
// writes the packet out to the write buffer, splitting it if neccessary, and flushing TDS packets one at a time
96+
pub(crate) async fn flush(&mut self) -> Result<(), Error> {
97+
// flush self.max_packet_size bytes at a time
98+
while self.inner.wbuf.len() > self.max_packet_size {
99+
let rest = self.inner.wbuf.split_off(self.max_packet_size);
100+
self.inner.flush().await?;
101+
self.inner.wbuf = rest;
102+
}
103+
self.inner.flush().await?;
104+
Ok(())
105+
}
106+
79107
// receive the next packet from the database
80108
// blocks until a packet is available
81109
pub(super) async fn recv_packet(&mut self) -> Result<(PacketHeader, Bytes), Error> {

sqlx-core/src/mssql/connection/tls_prelogin_stream_wrapper.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,12 @@ impl<S> TlsPreloginWrapper<S> {
6262
}
6363

6464
pub fn start_handshake(&mut self) {
65+
log::trace!("Handshake starting");
6566
self.pending_handshake = true;
6667
}
6768

6869
pub fn handshake_complete(&mut self) {
70+
log::trace!("Handshake complete");
6971
self.pending_handshake = false;
7072
}
7173
}
@@ -159,6 +161,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsPreloginWrapper
159161
}
160162

161163
while !inner.wr_buf.is_empty() {
164+
log::trace!("Writing {} bytes of TLS handshake", inner.wr_buf.len());
165+
162166
let written = ready!(Pin::new(&mut inner.stream).poll_write(cx, &inner.wr_buf))?;
163167

164168
inner.wr_buf.drain(..written);

sqlx-core/src/mssql/protocol/packet.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,21 @@ pub(crate) struct PacketHeader {
2727
pub(crate) packet_id: u8,
2828
}
2929

30+
impl PacketHeader {
31+
fn to_array(&self) -> [u8; PACKET_HEADER_SIZE] {
32+
let mut arr = [0u8; PACKET_HEADER_SIZE];
33+
arr[0] = self.r#type as u8;
34+
arr[1] = self.status.bits();
35+
arr[2..4].copy_from_slice(&self.length.to_be_bytes());
36+
arr[4..6].copy_from_slice(&self.server_process_id.to_be_bytes());
37+
arr[6] = self.packet_id;
38+
arr
39+
}
40+
}
41+
3042
impl<'s> Encode<'s, ()> for PacketHeader {
3143
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
32-
buf.push(self.r#type as u8);
33-
buf.push(self.status.bits());
34-
35-
buf.extend(&self.length.to_be_bytes());
36-
37-
buf.extend(&self.server_process_id.to_be_bytes());
38-
buf.push(self.packet_id);
39-
40-
// window, unused
41-
buf.push(0);
44+
buf.extend_from_slice(&self.to_array());
4245
}
4346
}
4447

sqlx-core/src/mssql/transaction.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ impl TransactionManager for MssqlTransactionManager {
8080

8181
conn.stream.pending_done_count += 1;
8282

83+
// We cannot flush since we are in a synchronous context, but the packet will always be small here
8384
conn.stream.write_packet(
8485
PacketType::SqlBatch,
8586
SqlBatch {

sqlx-core/src/net/tls/rustls.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, Server
22
use rustls::client::WebPkiServerVerifier;
33
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
44
use rustls::{
5-
CertificateError, ClientConfig, DigitallySignedStruct, Error as TlsError, RootCertStore,
6-
SignatureScheme,
5+
CertificateError, ClientConfig, DigitallySignedStruct, Error as TlsError, KeyLogFile,
6+
RootCertStore, SignatureScheme,
77
};
88
use std::io::{BufReader, Cursor};
99
use std::sync::Arc;
@@ -50,7 +50,7 @@ pub async fn configure_tls_connector(
5050
};
5151

5252
// authentication using user's key and its associated certificate
53-
let config = match (tls_config.client_cert_path, tls_config.client_key_path) {
53+
let mut config = match (tls_config.client_cert_path, tls_config.client_key_path) {
5454
(Some(cert_path), Some(key_path)) => {
5555
let cert_chain = certs_from_pem(cert_path.data().await?)?;
5656
let key_der = private_key_from_pem(key_path.data().await?)?;
@@ -66,6 +66,9 @@ pub async fn configure_tls_connector(
6666
}
6767
};
6868

69+
// When SSLKEYLOGFILE is set, write the TLS keys to a file for use with Wireshark
70+
config.key_log = Arc::new(KeyLogFile::new());
71+
6972
Ok(Arc::new(config).into())
7073
}
7174

tests/mssql/mssql.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ async fn it_binds_string_with_special_chars() -> anyhow::Result<()> {
205205
async fn it_accepts_long_query_strings() -> anyhow::Result<()> {
206206
let mut conn = new::<Mssql>().await?;
207207
// try a query that does not fit in a single TDS packet
208-
let (n,): (i32,) = sqlx_oldapi::query_as(&format!("SELECT {} 42", " ".repeat(0x1_00_00)))
208+
let (n,): (i32,) = sqlx_oldapi::query_as(&format!("SELECT {} 42", " ".repeat(3000)))
209209
.fetch_one(&mut conn)
210210
.await?;
211211
assert_eq!(n, 42);

0 commit comments

Comments
 (0)