Skip to content

Commit 916301f

Browse files
committed
WIP: add support for mssql + TLS
1 parent ed3b725 commit 916301f

File tree

10 files changed

+300
-30
lines changed

10 files changed

+300
-30
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## 0.6.28
9+
10+
- Add sqlx version information to pre-login message in mssql
11+
- Add support for encrypted Microsoft SQL server connections (using TLS)
12+
813
## 0.6.27
914

1015
- Fix pg i8 decode

Cargo.lock

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlx-core/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ runtime-tokio-rustls = [
9090

9191
# for conditional compilation
9292
_rt-async-std = []
93-
_rt-tokio = ["tokio-stream"]
93+
_rt-tokio = ["tokio-stream", "tokio-util"]
9494
_tls-native-tls = []
9595
_tls-rustls = ["rustls", "rustls-pemfile", "webpki-roots"]
9696

@@ -119,7 +119,7 @@ either = "1.6.1"
119119
futures-channel = { version = "0.3.19", default-features = false, features = ["sink", "alloc", "std"] }
120120
futures-core = { version = "0.3.19", default-features = false }
121121
futures-intrusive = "0.5.0"
122-
futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink"] }
122+
futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] }
123123
# used by the SQLite worker thread to block on the async mutex that locks the database handle
124124
futures-executor = { version = "0.3.19", optional = true }
125125
flume = { version = "0.11.0", optional = true, default-features = false, features = ["async"] }
@@ -154,6 +154,7 @@ sqlformat = "0.2.0"
154154
thiserror = "1.0.30"
155155
time = { version = "0.3.2", features = ["macros", "formatting", "parsing"], optional = true }
156156
tokio-stream = { version = "0.1.8", features = ["fs"], optional = true }
157+
tokio-util = { version = "0.7.0", features = ["compat"], default-features = false, optional = true }
157158
smallvec = "1.7.0"
158159
url = { version = "2.2.2", default-features = false }
159160
uuid = { version = "1.0", default-features = false, optional = true, features = ["std"] }

sqlx-core/src/mssql/connection/establish.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,19 @@ impl MssqlConnection {
1818
// TODO: Encryption
1919
// TODO: Send the version of SQLx over
2020

21+
let encryption = match options.encrypt {
22+
Some(true) => Encrypt::Required,
23+
Some(false) => Encrypt::NotSupported,
24+
None => Encrypt::On,
25+
};
26+
27+
log::debug!("Sending T-SQL PRELOGIN with encryption: {:?}", encryption);
28+
2129
stream.write_packet(
2230
PacketType::PreLogin,
2331
PreLogin {
2432
version: Version::default(),
25-
encryption: Encrypt::NOT_SUPPORTED,
33+
encryption,
2634
instance: options.instance.clone(),
2735

2836
..Default::default()
@@ -32,7 +40,14 @@ impl MssqlConnection {
3240
stream.flush().await?;
3341

3442
let (_, packet) = stream.recv_packet().await?;
35-
let _ = PreLogin::decode(packet)?;
43+
let prelogin_response = PreLogin::decode(packet)?;
44+
45+
if matches!(
46+
prelogin_response.encryption,
47+
Encrypt::Required | Encrypt::On
48+
) {
49+
stream.setup_encryption().await?;
50+
}
3651

3752
// LOGIN7 defines the authentication rules for use between client and server
3853

sqlx-core/src/mssql/connection/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ mod establish;
1515
mod executor;
1616
mod prepare;
1717
mod stream;
18+
mod tls_prelogin_stream_wrapper;
1819

1920
pub struct MssqlConnection {
2021
pub(crate) stream: MssqlStream,

sqlx-core/src/mssql/connection/stream.rs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use sqlx_rt::TcpStream;
66
use crate::error::Error;
77
use crate::ext::ustr::UStr;
88
use crate::io::{BufStream, Encode};
9+
use crate::mssql::connection::tls_prelogin_stream_wrapper::TlsPreloginWrapper;
910
use crate::mssql::protocol::col_meta_data::ColMetaData;
1011
use crate::mssql::protocol::done::{Done, Status as DoneStatus};
1112
use crate::mssql::protocol::env_change::EnvChange;
@@ -19,12 +20,12 @@ use crate::mssql::protocol::return_status::ReturnStatus;
1920
use crate::mssql::protocol::return_value::ReturnValue;
2021
use crate::mssql::protocol::row::Row;
2122
use crate::mssql::{MssqlColumn, MssqlConnectOptions, MssqlDatabaseError};
22-
use crate::net::MaybeTlsStream;
23+
use crate::net::{MaybeTlsStream, TlsConfig};
2324
use crate::HashMap;
2425
use std::sync::Arc;
2526

2627
pub(crate) struct MssqlStream {
27-
inner: BufStream<MaybeTlsStream<TcpStream>>,
28+
inner: BufStream<MaybeTlsStream<TlsPreloginWrapper<TcpStream>>>,
2829

2930
// how many Done (or Error) we are currently waiting for
3031
pub(crate) pending_done_count: usize,
@@ -44,13 +45,15 @@ pub(crate) struct MssqlStream {
4445

4546
// Maximum size of packets to send to the server
4647
pub(crate) max_packet_size: usize,
48+
49+
options: MssqlConnectOptions,
4750
}
4851

4952
impl MssqlStream {
5053
pub(super) async fn connect(options: &MssqlConnectOptions) -> Result<Self, Error> {
51-
let inner = BufStream::new(MaybeTlsStream::Raw(
52-
TcpStream::connect((&*options.host, options.port)).await?,
53-
));
54+
let tcp_stream = TcpStream::connect((&*options.host, options.port)).await?;
55+
let wrapped_stream = TlsPreloginWrapper::new(tcp_stream);
56+
let inner = BufStream::new(MaybeTlsStream::Raw(wrapped_stream));
5457

5558
Ok(Self {
5659
inner,
@@ -64,6 +67,7 @@ impl MssqlStream {
6467
.requested_packet_size
6568
.try_into()
6669
.unwrap_or(usize::MAX),
70+
options: options.clone(),
6771
})
6872
}
6973

@@ -206,10 +210,25 @@ impl MssqlStream {
206210

207211
Ok(())
208212
}
213+
214+
pub(crate) async fn setup_encryption(&mut self) -> Result<(), Error> {
215+
let tls_config = TlsConfig {
216+
accept_invalid_certs: true,
217+
hostname: &self.options.host,
218+
accept_invalid_hostnames: true,
219+
root_cert_path: None,
220+
client_cert_path: None,
221+
client_key_path: None,
222+
};
223+
self.inner.deref_mut().start_handshake();
224+
self.inner.upgrade(tls_config).await?;
225+
self.inner.deref_mut().handshake_complete();
226+
Ok(())
227+
}
209228
}
210229

211230
// writes the packet out to the write buffer
212-
fn write_packets<'en, T: Encode<'en>>(
231+
pub(crate) fn write_packets<'en, T: Encode<'en>>(
213232
buffer: &mut Vec<u8>,
214233
max_packet_size: usize,
215234
ty: PacketType,
@@ -313,7 +332,7 @@ fn test_write_packets() {
313332
}
314333

315334
impl Deref for MssqlStream {
316-
type Target = BufStream<MaybeTlsStream<TcpStream>>;
335+
type Target = BufStream<MaybeTlsStream<TlsPreloginWrapper<TcpStream>>>;
317336

318337
fn deref(&self) -> &Self::Target {
319338
&self.inner
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
// Original implementation from tiberius: https://github.com/prisma/tiberius/blob/main/src/client/tls.rs
2+
3+
use crate::mssql::protocol::packet::{PacketHeader, PacketType};
4+
5+
use super::stream::write_packets;
6+
7+
use crate::io::Decode;
8+
use bytes::Bytes;
9+
use sqlx_rt::{AsyncRead, AsyncWrite, ReadBuf};
10+
use std::cmp;
11+
use std::io;
12+
use std::pin::Pin;
13+
use std::task::{self, ready, Poll};
14+
15+
/// This wrapper handles TDS (Tabular Data Stream) packet encapsulation during the TLS handshake phase
16+
/// of a connection to a Microsoft SQL Server.
17+
///
18+
/// In the PRELOGIN phase of the TDS protocol, all communication must be wrapped in TDS packets,
19+
/// even during TLS negotiation. This presents a challenge when using standard TLS libraries,
20+
/// which expect to work with raw TCP streams.
21+
///
22+
/// This wrapper solves the problem by:
23+
/// 1. During handshake:
24+
/// - For writes: It buffers outgoing data and wraps it in TDS packets before sending.
25+
/// Each packet starts with an 8-byte header containing type (0x12 for PRELOGIN),
26+
/// status flags, length, and other metadata.
27+
/// - For reads: It strips the TDS packet headers from incoming data before passing
28+
/// it to the TLS library.
29+
/// 2. After handshake:
30+
/// - It becomes transparent, directly passing through all reads and writes to the
31+
/// underlying stream without modification.
32+
///
33+
/// This allows us to use standard TLS libraries while still conforming to the TDS protocol
34+
/// requirements for the PRELOGIN phase.
35+
36+
const HEADER_BYTES: usize = 8;
37+
38+
pub(crate) struct TlsPreloginWrapper<S> {
39+
stream: S,
40+
pending_handshake: bool,
41+
42+
header_buf: [u8; HEADER_BYTES],
43+
header_pos: usize,
44+
read_remaining: usize,
45+
46+
wr_buf: Vec<u8>,
47+
header_written: bool,
48+
}
49+
50+
impl<S> TlsPreloginWrapper<S> {
51+
pub fn new(stream: S) -> Self {
52+
TlsPreloginWrapper {
53+
stream,
54+
pending_handshake: false,
55+
56+
header_buf: [0u8; HEADER_BYTES],
57+
header_pos: 0,
58+
read_remaining: 0,
59+
wr_buf: Vec::new(),
60+
header_written: false,
61+
}
62+
}
63+
64+
pub fn start_handshake(&mut self) {
65+
self.pending_handshake = true;
66+
}
67+
68+
pub fn handshake_complete(&mut self) {
69+
self.pending_handshake = false;
70+
}
71+
}
72+
73+
impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<S> {
74+
fn poll_read(
75+
mut self: Pin<&mut Self>,
76+
cx: &mut task::Context<'_>,
77+
buf: &mut ReadBuf<'_>,
78+
) -> Poll<io::Result<()>> {
79+
if !self.pending_handshake {
80+
return Pin::new(&mut self.stream).poll_read(cx, buf);
81+
}
82+
83+
let inner = self.get_mut();
84+
85+
if !inner.header_buf[inner.header_pos..].is_empty() {
86+
while !inner.header_buf[inner.header_pos..].is_empty() {
87+
let mut header_buf = ReadBuf::new(&mut inner.header_buf[inner.header_pos..]);
88+
ready!(Pin::new(&mut inner.stream).poll_read(cx, &mut header_buf))?;
89+
90+
let read = header_buf.filled().len();
91+
if read == 0 {
92+
return Poll::Ready(Ok(()));
93+
}
94+
95+
inner.header_pos += read;
96+
}
97+
98+
let header: PacketHeader = Decode::decode(Bytes::copy_from_slice(&inner.header_buf))
99+
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
100+
101+
inner.read_remaining = usize::from(header.length) - HEADER_BYTES;
102+
103+
log::trace!(
104+
"Discarding header ({:?}), reading packet of {} bytes",
105+
header,
106+
inner.read_remaining,
107+
);
108+
}
109+
110+
let max_read = std::cmp::min(inner.read_remaining, buf.remaining());
111+
let mut limited_buf = buf.take(max_read);
112+
113+
ready!(Pin::new(&mut inner.stream).poll_read(cx, &mut limited_buf))?;
114+
115+
let read = limited_buf.filled().len();
116+
buf.advance(read);
117+
inner.read_remaining -= read;
118+
119+
if inner.read_remaining == 0 {
120+
inner.header_pos = 0;
121+
}
122+
123+
Poll::Ready(Ok(()))
124+
}
125+
}
126+
127+
impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsPreloginWrapper<S> {
128+
fn poll_write(
129+
mut self: Pin<&mut Self>,
130+
cx: &mut task::Context<'_>,
131+
buf: &[u8],
132+
) -> Poll<io::Result<usize>> {
133+
// Normal operation does not need any extra treatment, we handle
134+
// packets in the codec.
135+
if !self.pending_handshake {
136+
return Pin::new(&mut self.stream).poll_write(cx, buf);
137+
}
138+
139+
// Buffering data.
140+
self.wr_buf.extend_from_slice(buf);
141+
142+
Poll::Ready(Ok(buf.len()))
143+
}
144+
145+
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
146+
let inner = self.get_mut();
147+
148+
// If on handshake mode, wraps the data to a TDS packet before sending.
149+
if inner.pending_handshake {
150+
if !inner.header_written {
151+
let buf = std::mem::take(&mut inner.wr_buf);
152+
write_packets(
153+
&mut inner.wr_buf,
154+
4096,
155+
PacketType::PreLogin,
156+
buf.as_slice(),
157+
);
158+
inner.header_written = true;
159+
}
160+
161+
while !inner.wr_buf.is_empty() {
162+
let written = ready!(Pin::new(&mut inner.stream).poll_write(cx, &inner.wr_buf))?;
163+
164+
inner.wr_buf.drain(..written);
165+
}
166+
167+
inner.header_written = false;
168+
}
169+
170+
Pin::new(&mut inner.stream).poll_flush(cx)
171+
}
172+
173+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
174+
Pin::new(&mut self.stream).poll_shutdown(cx)
175+
}
176+
}

0 commit comments

Comments
 (0)