Skip to content

Commit 1f6d9dd

Browse files
committed
Overhaul query cancellation
Multi-host support means we can't simply take the old approach - we need to know which of the hosts we actually connected to. It's also nice to move this from the connection to the client since that's what you'd normally have access to.
1 parent a6535b4 commit 1f6d9dd

File tree

11 files changed

+305
-78
lines changed

11 files changed

+305
-78
lines changed

tokio-postgres/src/config.rs

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ use tokio_io::{AsyncRead, AsyncWrite};
1717

1818
#[cfg(feature = "runtime")]
1919
use crate::proto::ConnectFuture;
20-
use crate::proto::{CancelQueryRawFuture, HandshakeFuture};
21-
use crate::{CancelData, CancelQueryRaw, Error, Handshake, TlsMode};
20+
use crate::proto::HandshakeFuture;
2221
#[cfg(feature = "runtime")]
2322
use crate::{Connect, MakeTlsMode, Socket};
23+
use crate::{Error, Handshake, TlsMode};
2424

2525
#[cfg(feature = "runtime")]
2626
#[derive(Debug, Copy, Clone, PartialEq)]
@@ -267,7 +267,7 @@ impl Config {
267267
S: AsyncRead + AsyncWrite,
268268
T: TlsMode<S>,
269269
{
270-
Handshake(HandshakeFuture::new(stream, tls_mode, self.clone()))
270+
Handshake(HandshakeFuture::new(stream, tls_mode, self.clone(), None))
271271
}
272272

273273
#[cfg(feature = "runtime")]
@@ -277,19 +277,6 @@ impl Config {
277277
{
278278
Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone())))
279279
}
280-
281-
pub fn cancel_query_raw<S, T>(
282-
&self,
283-
stream: S,
284-
tls_mode: T,
285-
cancel_data: CancelData,
286-
) -> CancelQueryRaw<S, T>
287-
where
288-
S: AsyncRead + AsyncWrite,
289-
T: TlsMode<S>,
290-
{
291-
CancelQueryRaw(CancelQueryRawFuture::new(stream, tls_mode, cancel_data))
292-
}
293280
}
294281

295282
impl FromStr for Config {

tokio-postgres/src/lib.rs

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,22 @@ impl Client {
206206
BatchExecute(self.0.batch_execute(query))
207207
}
208208

209+
#[cfg(feature = "runtime")]
210+
pub fn cancel_query<T>(&mut self, make_tls_mode: T) -> CancelQuery<T>
211+
where
212+
T: MakeTlsMode<Socket>,
213+
{
214+
CancelQuery(self.0.cancel_query(make_tls_mode))
215+
}
216+
217+
pub fn cancel_query_raw<S, T>(&mut self, stream: S, tls_mode: T) -> CancelQueryRaw<S, T>
218+
where
219+
S: AsyncRead + AsyncWrite,
220+
T: TlsMode<S>,
221+
{
222+
CancelQueryRaw(self.0.cancel_query_raw(stream, tls_mode))
223+
}
224+
209225
pub fn is_closed(&self) -> bool {
210226
self.0.is_closed()
211227
}
@@ -222,10 +238,6 @@ impl<S> Connection<S>
222238
where
223239
S: AsyncRead + AsyncWrite,
224240
{
225-
pub fn cancel_data(&self) -> CancelData {
226-
self.0.cancel_data()
227-
}
228-
229241
pub fn parameter(&self, name: &str) -> Option<&str> {
230242
self.0.parameter(name)
231243
}
@@ -274,6 +286,25 @@ where
274286
}
275287
}
276288

289+
#[cfg(feature = "runtime")]
290+
#[must_use = "futures do nothing unless polled"]
291+
pub struct CancelQuery<T>(proto::CancelQueryFuture<T>)
292+
where
293+
T: MakeTlsMode<Socket>;
294+
295+
#[cfg(feature = "runtime")]
296+
impl<T> Future for CancelQuery<T>
297+
where
298+
T: MakeTlsMode<Socket>,
299+
{
300+
type Item = ();
301+
type Error = Error;
302+
303+
fn poll(&mut self) -> Poll<(), Error> {
304+
self.0.poll()
305+
}
306+
}
307+
277308
#[must_use = "futures do nothing unless polled"]
278309
pub struct Handshake<S, T>(proto::HandshakeFuture<S, T>)
279310
where
@@ -478,15 +509,6 @@ impl Future for BatchExecute {
478509
}
479510
}
480511

481-
/// Contains information necessary to cancel queries for a session.
482-
#[derive(Copy, Clone, Debug)]
483-
pub struct CancelData {
484-
/// The process ID of the session.
485-
pub process_id: i32,
486-
/// The secret key for the session.
487-
pub secret_key: i32,
488-
}
489-
490512
/// An asynchronous notification.
491513
#[derive(Clone, Debug)]
492514
pub struct Notification {
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
use futures::{try_ready, Future, Poll};
2+
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
3+
use std::io;
4+
5+
use crate::proto::{CancelQueryRawFuture, ConnectSocketFuture};
6+
use crate::{Config, Error, Host, MakeTlsMode, Socket};
7+
8+
#[derive(StateMachineFuture)]
9+
pub enum CancelQuery<T>
10+
where
11+
T: MakeTlsMode<Socket>,
12+
{
13+
#[state_machine_future(start, transitions(ConnectingSocket))]
14+
Start {
15+
make_tls_mode: T,
16+
idx: Option<usize>,
17+
config: Config,
18+
process_id: i32,
19+
secret_key: i32,
20+
},
21+
#[state_machine_future(transitions(Canceling))]
22+
ConnectingSocket {
23+
future: ConnectSocketFuture,
24+
tls_mode: T::TlsMode,
25+
process_id: i32,
26+
secret_key: i32,
27+
},
28+
#[state_machine_future(transitions(Finished))]
29+
Canceling {
30+
future: CancelQueryRawFuture<Socket, T::TlsMode>,
31+
},
32+
#[state_machine_future(ready)]
33+
Finished(()),
34+
#[state_machine_future(error)]
35+
Failed(Error),
36+
}
37+
38+
impl<T> PollCancelQuery<T> for CancelQuery<T>
39+
where
40+
T: MakeTlsMode<Socket>,
41+
{
42+
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
43+
let mut state = state.take();
44+
45+
let idx = state.idx.ok_or_else(|| {
46+
Error::connect(io::Error::new(io::ErrorKind::InvalidInput, "unknown host"))
47+
})?;
48+
49+
let hostname = match &state.config.0.host[idx] {
50+
Host::Tcp(host) => &**host,
51+
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
52+
#[cfg(unix)]
53+
Host::Unix(_) => "",
54+
};
55+
let tls_mode = state
56+
.make_tls_mode
57+
.make_tls_mode(hostname)
58+
.map_err(|e| Error::tls(e.into()))?;
59+
60+
transition!(ConnectingSocket {
61+
future: ConnectSocketFuture::new(state.config, idx),
62+
tls_mode,
63+
process_id: state.process_id,
64+
secret_key: state.secret_key,
65+
})
66+
}
67+
68+
fn poll_connecting_socket<'a>(
69+
state: &'a mut RentToOwn<'a, ConnectingSocket<T>>,
70+
) -> Poll<AfterConnectingSocket<T>, Error> {
71+
let socket = try_ready!(state.future.poll());
72+
let state = state.take();
73+
74+
transition!(Canceling {
75+
future: CancelQueryRawFuture::new(
76+
socket,
77+
state.tls_mode,
78+
state.process_id,
79+
state.secret_key
80+
),
81+
})
82+
}
83+
84+
fn poll_canceling<'a>(
85+
state: &'a mut RentToOwn<'a, Canceling<T>>,
86+
) -> Poll<AfterCanceling, Error> {
87+
try_ready!(state.future.poll());
88+
transition!(Finished(()))
89+
}
90+
}
91+
92+
impl<T> CancelQueryFuture<T>
93+
where
94+
T: MakeTlsMode<Socket>,
95+
{
96+
pub fn new(
97+
make_tls_mode: T,
98+
idx: Option<usize>,
99+
config: Config,
100+
process_id: i32,
101+
secret_key: i32,
102+
) -> CancelQueryFuture<T> {
103+
CancelQuery::start(make_tls_mode, idx, config, process_id, secret_key)
104+
}
105+
}

tokio-postgres/src/proto/cancel_query_raw.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
66

77
use crate::error::Error;
88
use crate::proto::TlsFuture;
9-
use crate::{CancelData, TlsMode};
9+
use crate::TlsMode;
1010

1111
#[derive(StateMachineFuture)]
1212
pub enum CancelQueryRaw<S, T>
@@ -17,7 +17,8 @@ where
1717
#[state_machine_future(start, transitions(SendingCancel))]
1818
Start {
1919
future: TlsFuture<S, T>,
20-
cancel_data: CancelData,
20+
process_id: i32,
21+
secret_key: i32,
2122
},
2223
#[state_machine_future(transitions(FlushingCancel))]
2324
SendingCancel {
@@ -40,11 +41,7 @@ where
4041
let (stream, _) = try_ready!(state.future.poll());
4142

4243
let mut buf = vec![];
43-
frontend::cancel_request(
44-
state.cancel_data.process_id,
45-
state.cancel_data.secret_key,
46-
&mut buf,
47-
);
44+
frontend::cancel_request(state.process_id, state.secret_key, &mut buf);
4845

4946
transition!(SendingCancel {
5047
future: io::write_all(stream, buf),
@@ -74,7 +71,12 @@ where
7471
S: AsyncRead + AsyncWrite,
7572
T: TlsMode<S>,
7673
{
77-
pub fn new(stream: S, tls_mode: T, cancel_data: CancelData) -> CancelQueryRawFuture<S, T> {
78-
CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), cancel_data)
74+
pub fn new(
75+
stream: S,
76+
tls_mode: T,
77+
process_id: i32,
78+
secret_key: i32,
79+
) -> CancelQueryRawFuture<S, T> {
80+
CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), process_id, secret_key)
7981
}
8082
}

tokio-postgres/src/proto/client.rs

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use postgres_protocol::message::frontend;
88
use std::collections::HashMap;
99
use std::error::Error as StdError;
1010
use std::sync::{Arc, Weak};
11+
use tokio_io::{AsyncRead, AsyncWrite};
1112

1213
use crate::proto::bind::BindFuture;
1314
use crate::proto::connection::{Request, RequestMessages};
@@ -20,8 +21,13 @@ use crate::proto::prepare::PrepareFuture;
2021
use crate::proto::query::QueryStream;
2122
use crate::proto::simple_query::SimpleQueryStream;
2223
use crate::proto::statement::Statement;
24+
#[cfg(feature = "runtime")]
25+
use crate::proto::CancelQueryFuture;
26+
use crate::proto::CancelQueryRawFuture;
2327
use crate::types::{IsNull, Oid, ToSql, Type};
24-
use crate::Error;
28+
use crate::{Config, Error, TlsMode};
29+
#[cfg(feature = "runtime")]
30+
use crate::{MakeTlsMode, Socket};
2531

2632
pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>);
2733

@@ -44,13 +50,25 @@ struct Inner {
4450
state: Mutex<State>,
4551
idle: IdleState,
4652
sender: mpsc::UnboundedSender<Request>,
53+
process_id: i32,
54+
secret_key: i32,
55+
#[cfg_attr(not(feature = "runtime"), allow(dead_code))]
56+
config: Config,
57+
#[cfg_attr(not(feature = "runtime"), allow(dead_code))]
58+
idx: Option<usize>,
4759
}
4860

4961
#[derive(Clone)]
5062
pub struct Client(Arc<Inner>);
5163

5264
impl Client {
53-
pub fn new(sender: mpsc::UnboundedSender<Request>) -> Client {
65+
pub fn new(
66+
sender: mpsc::UnboundedSender<Request>,
67+
process_id: i32,
68+
secret_key: i32,
69+
config: Config,
70+
idx: Option<usize>,
71+
) -> Client {
5472
Client(Arc::new(Inner {
5573
state: Mutex::new(State {
5674
types: HashMap::new(),
@@ -60,6 +78,10 @@ impl Client {
6078
}),
6179
idle: IdleState::new(),
6280
sender,
81+
process_id,
82+
secret_key,
83+
config,
84+
idx,
6385
}))
6486
}
6587

@@ -222,6 +244,28 @@ impl Client {
222244
self.close(b'P', name)
223245
}
224246

247+
#[cfg(feature = "runtime")]
248+
pub fn cancel_query<T>(&self, make_tls_mode: T) -> CancelQueryFuture<T>
249+
where
250+
T: MakeTlsMode<Socket>,
251+
{
252+
CancelQueryFuture::new(
253+
make_tls_mode,
254+
self.0.idx,
255+
self.0.config.clone(),
256+
self.0.process_id,
257+
self.0.secret_key,
258+
)
259+
}
260+
261+
pub fn cancel_query_raw<S, T>(&self, stream: S, tls_mode: T) -> CancelQueryRawFuture<S, T>
262+
where
263+
S: AsyncRead + AsyncWrite,
264+
T: TlsMode<S>,
265+
{
266+
CancelQueryRawFuture::new(stream, tls_mode, self.0.process_id, self.0.secret_key)
267+
}
268+
225269
fn close(&self, ty: u8, name: &str) {
226270
let mut buf = vec![];
227271
frontend::close(ty, name, &mut buf).expect("statement name not valid");

tokio-postgres/src/proto/connect_once.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ where
6565

6666
transition!(Handshaking {
6767
target_session_attrs: state.config.0.target_session_attrs,
68-
future: HandshakeFuture::new(socket, state.tls_mode, state.config),
68+
future: HandshakeFuture::new(socket, state.tls_mode, state.config, Some(state.idx)),
6969
})
7070
}
7171

0 commit comments

Comments
 (0)