Skip to content

Commit eb74c70

Browse files
authored
Support for Unix sockets in HTTP connect proxy (#984)
Fixes #981
1 parent 6899691 commit eb74c70

File tree

6 files changed

+355
-15
lines changed

6 files changed

+355
-15
lines changed

client/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ extern crate tracing;
99

1010
pub mod callback_based;
1111
mod metrics;
12-
mod proxy;
12+
/// Visible only for tests
13+
#[doc(hidden)]
14+
pub mod proxy;
1315
mod raw;
1416
mod retry;
1517
mod worker_registry;

client/src/proxy.rs

Lines changed: 124 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,33 @@ use base64::prelude::*;
22
use http_body_util::Empty;
33
use hyper::{body::Bytes, header};
44
use hyper_util::{
5-
client::legacy::Client,
5+
client::legacy::{
6+
Client,
7+
connect::{Connected, Connection},
8+
},
69
rt::{TokioExecutor, TokioIo},
710
};
811
use std::{
912
future::Future,
13+
io,
1014
pin::Pin,
1115
task::{Context, Poll},
1216
};
13-
use tokio::net::TcpStream;
17+
use tokio::{
18+
io::{AsyncRead, AsyncWrite, ReadBuf},
19+
net::TcpStream,
20+
};
1421
use tonic::transport::{Channel, Endpoint};
1522
use tower::{Service, service_fn};
1623

24+
#[cfg(unix)]
25+
use tokio::net::UnixStream;
26+
1727
/// Options for HTTP CONNECT proxy.
1828
#[derive(Clone, Debug)]
1929
pub struct HttpConnectProxyOptions {
20-
/// The host:port to proxy through.
30+
/// The host:port to proxy through for TCP, or unix:/path/to/unix.sock for
31+
/// Unix socket (which means it must start with "unix:/").
2132
pub target_addr: String,
2233
/// Optional HTTP basic auth for the proxy as user/pass tuple.
2334
pub basic_auth: Option<(String, String)>,
@@ -72,7 +83,7 @@ impl HttpConnectProxyOptions {
7283
struct OverrideAddrConnector(String);
7384

7485
impl Service<hyper::Uri> for OverrideAddrConnector {
75-
type Response = TokioIo<TcpStream>;
86+
type Response = TokioIo<ProxyStream>;
7687

7788
type Error = anyhow::Error;
7889

@@ -84,7 +95,115 @@ impl Service<hyper::Uri> for OverrideAddrConnector {
8495

8596
fn call(&mut self, _uri: hyper::Uri) -> Self::Future {
8697
let target_addr = self.0.clone();
87-
let fut = async move { Ok(TokioIo::new(TcpStream::connect(target_addr).await?)) };
98+
let fut = async move {
99+
Ok(TokioIo::new(
100+
ProxyStream::connect(target_addr.as_str()).await?,
101+
))
102+
};
88103
Box::pin(fut)
89104
}
90105
}
106+
107+
/// Visible only for tests
108+
#[doc(hidden)]
109+
pub enum ProxyStream {
110+
Tcp(TcpStream),
111+
#[cfg(unix)]
112+
Unix(UnixStream),
113+
}
114+
115+
impl ProxyStream {
116+
async fn connect(target_addr: &str) -> anyhow::Result<Self> {
117+
if target_addr.starts_with("unix:/") {
118+
#[cfg(unix)]
119+
{
120+
Ok(ProxyStream::Unix(
121+
UnixStream::connect(&target_addr[5..]).await?,
122+
))
123+
}
124+
#[cfg(not(unix))]
125+
{
126+
Err(anyhow::anyhow!(
127+
"Unix sockets are not supported on this platform"
128+
))
129+
}
130+
} else {
131+
Ok(ProxyStream::Tcp(TcpStream::connect(target_addr).await?))
132+
}
133+
}
134+
}
135+
136+
impl AsyncRead for ProxyStream {
137+
fn poll_read(
138+
self: Pin<&mut Self>,
139+
cx: &mut Context<'_>,
140+
buf: &mut ReadBuf<'_>,
141+
) -> Poll<io::Result<()>> {
142+
match self.get_mut() {
143+
ProxyStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
144+
#[cfg(unix)]
145+
ProxyStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
146+
}
147+
}
148+
}
149+
150+
impl AsyncWrite for ProxyStream {
151+
fn poll_write(
152+
self: Pin<&mut Self>,
153+
cx: &mut Context<'_>,
154+
buf: &[u8],
155+
) -> Poll<io::Result<usize>> {
156+
match self.get_mut() {
157+
ProxyStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
158+
#[cfg(unix)]
159+
ProxyStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
160+
}
161+
}
162+
163+
fn poll_write_vectored(
164+
self: Pin<&mut Self>,
165+
cx: &mut Context<'_>,
166+
bufs: &[io::IoSlice<'_>],
167+
) -> Poll<io::Result<usize>> {
168+
match self.get_mut() {
169+
ProxyStream::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
170+
#[cfg(unix)]
171+
ProxyStream::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
172+
}
173+
}
174+
175+
fn is_write_vectored(&self) -> bool {
176+
match self {
177+
ProxyStream::Tcp(s) => s.is_write_vectored(),
178+
#[cfg(unix)]
179+
ProxyStream::Unix(s) => s.is_write_vectored(),
180+
}
181+
}
182+
183+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184+
match self.get_mut() {
185+
ProxyStream::Tcp(s) => Pin::new(s).poll_flush(cx),
186+
#[cfg(unix)]
187+
ProxyStream::Unix(s) => Pin::new(s).poll_flush(cx),
188+
}
189+
}
190+
191+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
192+
match self.get_mut() {
193+
ProxyStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
194+
#[cfg(unix)]
195+
ProxyStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
196+
}
197+
}
198+
}
199+
200+
impl Connection for ProxyStream {
201+
fn connected(&self) -> Connected {
202+
match self {
203+
ProxyStream::Tcp(s) => s.connected(),
204+
// There is no special connected metadata for Unix sockets
205+
#[cfg(unix)]
206+
ProxyStream::Unix(_) => Connected::new(),
207+
}
208+
}
209+
}

test-utils/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ ephemeral-server = ["temporal-sdk-core/ephemeral-server"]
1717
anyhow = "1.0"
1818
assert_matches = "1"
1919
async-trait = "0.1"
20+
bytes = "1.10"
2021
futures-util = { version = "0.3", default-features = false }
22+
hyper = { version = "1.4.1" }
23+
http-body-util = "0.1"
24+
hyper-util = "0.1.6"
2125
parking_lot = "0.12"
2226
prost = { workspace = true }
2327
rand = "0.9"

test-utils/src/http_proxy.rs

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
use std::{
2+
io,
3+
sync::{
4+
Arc,
5+
atomic::{AtomicUsize, Ordering},
6+
},
7+
};
8+
9+
use bytes::Bytes;
10+
use http_body_util::Empty;
11+
use hyper::{
12+
Request, Response, StatusCode, body::Incoming, server::conn::http1, service::service_fn,
13+
};
14+
use hyper_util::rt::TokioIo;
15+
use temporal_client::proxy::ProxyStream;
16+
#[cfg(unix)]
17+
use tokio::net::UnixListener;
18+
use tokio::{
19+
net::{TcpListener, TcpStream},
20+
sync::oneshot,
21+
};
22+
23+
pub struct HttpProxy {
24+
proxy_hits: Arc<AtomicUsize>,
25+
shutdown_tx: oneshot::Sender<()>,
26+
}
27+
impl HttpProxy {
28+
pub fn spawn_tcp(listener: TcpListener) -> Self {
29+
Self::spawn(ProxyListener::Tcp(listener))
30+
}
31+
32+
#[cfg(unix)]
33+
pub fn spawn_unix(listener: UnixListener) -> Self {
34+
Self::spawn(ProxyListener::Unix(listener))
35+
}
36+
37+
fn spawn(listener: ProxyListener) -> Self {
38+
let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
39+
let proxy_hits = Arc::new(AtomicUsize::new(0));
40+
let proxy_hits_cloned = proxy_hits.clone();
41+
tokio::spawn(async move {
42+
loop {
43+
let proxy_hits_cloned = proxy_hits_cloned.clone();
44+
tokio::select! {
45+
_ = &mut shutdown_rx => break,
46+
stream = listener.accept() => {
47+
let stream = match stream {
48+
Ok(stream) => stream,
49+
Err(e) => { println!("Proxy accept error: {e}"); continue; }
50+
};
51+
tokio::spawn(async move {
52+
if let Err(e) = http1::Builder::new()
53+
.serve_connection(
54+
TokioIo::new(stream),
55+
service_fn(move |req| handle_connect(req, proxy_hits_cloned.clone())),
56+
)
57+
.with_upgrades()
58+
.await
59+
{
60+
println!("Proxy conn error: {e}");
61+
}
62+
});
63+
}
64+
}
65+
}
66+
});
67+
Self {
68+
proxy_hits,
69+
shutdown_tx,
70+
}
71+
}
72+
73+
pub fn hit_count(&self) -> usize {
74+
self.proxy_hits.load(Ordering::SeqCst)
75+
}
76+
77+
/// Returns before shutdown occurs
78+
pub fn shutdown(self) {
79+
let _ = self.shutdown_tx.send(());
80+
}
81+
}
82+
83+
async fn handle_connect(
84+
req: Request<Incoming>,
85+
counter: Arc<AtomicUsize>,
86+
) -> Result<Response<Empty<Bytes>>, hyper::Error> {
87+
if req.method() == hyper::Method::CONNECT {
88+
// Increment atomic counter
89+
counter.fetch_add(1, Ordering::SeqCst);
90+
91+
// Tell the client the tunnel is established
92+
tokio::spawn(async move {
93+
if let Some(addr) = req.uri().authority().map(|a| a.as_str()) {
94+
match TcpStream::connect(addr).await {
95+
Ok(mut server_stream) => match hyper::upgrade::on(req).await {
96+
Ok(upgraded) => {
97+
let mut upgraded = TokioIo::new(upgraded);
98+
let _ =
99+
tokio::io::copy_bidirectional(&mut upgraded, &mut server_stream)
100+
.await;
101+
}
102+
Err(err) => println!("Upgrade failed: {err}"),
103+
},
104+
Err(e) => println!("Failed to connect to {addr}: {e}"),
105+
}
106+
}
107+
});
108+
109+
Ok(Response::builder()
110+
.status(StatusCode::OK)
111+
.body(Empty::new())
112+
.unwrap())
113+
} else {
114+
Ok(Response::builder()
115+
.status(StatusCode::METHOD_NOT_ALLOWED)
116+
.body(Empty::new())
117+
.unwrap())
118+
}
119+
}
120+
121+
enum ProxyListener {
122+
Tcp(TcpListener),
123+
#[cfg(unix)]
124+
Unix(UnixListener),
125+
}
126+
127+
impl ProxyListener {
128+
async fn accept(&self) -> io::Result<ProxyStream> {
129+
match self {
130+
ProxyListener::Tcp(tcp) => tcp.accept().await.map(|(s, _)| ProxyStream::Tcp(s)),
131+
#[cfg(unix)]
132+
ProxyListener::Unix(unix) => unix.accept().await.map(|(s, _)| ProxyStream::Unix(s)),
133+
}
134+
}
135+
}

test-utils/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
extern crate tracing;
66

77
pub mod canned_histories;
8+
mod http_proxy;
89
pub mod interceptors;
910
pub mod workflows;
1011

12+
pub use http_proxy::HttpProxy;
1113
pub use temporal_sdk_core::replay::HistoryForReplay;
1214

1315
use crate::stream::{Stream, TryStreamExt};

0 commit comments

Comments
 (0)