Skip to content

Commit b393228

Browse files
committed
Fix compilation for async-std
1 parent 77d5a64 commit b393228

File tree

2 files changed

+53
-5
lines changed

2 files changed

+53
-5
lines changed

sqlx-core/src/mssql/connection/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ impl Debug for MssqlConnection {
2929
}
3030
}
3131

32+
use std::ops::DerefMut;
33+
3234
impl Connection for MssqlConnection {
3335
type Database = Mssql;
3436

sqlx-core/src/mssql/connection/tls_prelogin_stream_wrapper.rs

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ use super::stream::write_packets;
66

77
use crate::io::Decode;
88
use bytes::Bytes;
9-
use sqlx_rt::{AsyncRead, AsyncWrite, ReadBuf};
9+
use sqlx_rt::{AsyncRead, AsyncWrite};
10+
11+
#[cfg(feature = "_rt-tokio")]
12+
use sqlx_rt::ReadBuf;
13+
1014
use std::cmp;
1115
use std::io;
1216
use std::pin::Pin;
@@ -72,12 +76,19 @@ impl<S> TlsPreloginWrapper<S> {
7276
}
7377
}
7478

79+
#[cfg(feature = "_rt-async-std")]
80+
type PollReadOut = usize;
81+
82+
#[cfg(feature = "_rt-tokio")]
83+
type PollReadOut = ();
84+
7585
impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<S> {
86+
#[cfg(feature = "_rt-tokio")]
7687
fn poll_read(
7788
mut self: Pin<&mut Self>,
7889
cx: &mut task::Context<'_>,
7990
buf: &mut ReadBuf<'_>,
80-
) -> Poll<io::Result<()>> {
91+
) -> Poll<io::Result<PollReadOut>> {
8192
if !self.pending_handshake {
8293
return Pin::new(&mut self.stream).poll_read(cx, buf);
8394
}
@@ -91,7 +102,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<
91102

92103
let read = header_buf.filled().len();
93104
if read == 0 {
94-
return Poll::Ready(Ok(()));
105+
return Poll::Ready(Ok(PollReadOut::default()));
95106
}
96107

97108
inner.header_pos += read;
@@ -112,7 +123,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<
112123
let max_read = std::cmp::min(inner.read_remaining, buf.remaining());
113124
let mut limited_buf = buf.take(max_read);
114125

115-
ready!(Pin::new(&mut inner.stream).poll_read(cx, &mut limited_buf))?;
126+
let res = ready!(Pin::new(&mut inner.stream).poll_read(cx, &mut limited_buf))?;
116127

117128
let read = limited_buf.filled().len();
118129
buf.advance(read);
@@ -122,7 +133,20 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<
122133
inner.header_pos = 0;
123134
}
124135

125-
Poll::Ready(Ok(()))
136+
Poll::Ready(Ok(res))
137+
}
138+
139+
#[cfg(feature = "_rt-async-std")]
140+
fn poll_read(
141+
mut self: Pin<&mut Self>,
142+
cx: &mut task::Context<'_>,
143+
buf: &mut [u8],
144+
) -> Poll<io::Result<usize>> {
145+
if self.pending_handshake {
146+
panic!("TLS not supported on async-std for mssql");
147+
}
148+
149+
Pin::new(&mut self.stream).poll_read(cx, buf)
126150
}
127151
}
128152

@@ -174,7 +198,29 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsPreloginWrapper
174198
Pin::new(&mut inner.stream).poll_flush(cx)
175199
}
176200

201+
#[cfg(feature = "_rt-tokio")]
177202
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
178203
Pin::new(&mut self.stream).poll_shutdown(cx)
179204
}
205+
206+
#[cfg(feature = "_rt-async-std")]
207+
fn poll_close(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
208+
Pin::new(&mut self.stream).poll_close(cx)
209+
}
210+
}
211+
212+
use std::ops::{Deref, DerefMut};
213+
214+
impl<S> Deref for TlsPreloginWrapper<S> {
215+
type Target = S;
216+
217+
fn deref(&self) -> &Self::Target {
218+
&self.stream
219+
}
220+
}
221+
222+
impl<S> DerefMut for TlsPreloginWrapper<S> {
223+
fn deref_mut(&mut self) -> &mut Self::Target {
224+
&mut self.stream
225+
}
180226
}

0 commit comments

Comments
 (0)