@@ -2,22 +2,33 @@ use base64::prelude::*;
22use http_body_util:: Empty ;
33use hyper:: { body:: Bytes , header} ;
44use hyper_util:: {
5- client:: legacy:: Client ,
5+ client:: legacy:: {
6+ Client ,
7+ connect:: { Connected , Connection } ,
8+ } ,
69 rt:: { TokioExecutor , TokioIo } ,
710} ;
811use std:: {
912 future:: Future ,
13+ io,
1014 pin:: Pin ,
1115 task:: { Context , Poll } ,
1216} ;
13- use tokio:: net:: TcpStream ;
17+ use tokio:: {
18+ io:: { AsyncRead , AsyncWrite , ReadBuf } ,
19+ net:: TcpStream ,
20+ } ;
1421use tonic:: transport:: { Channel , Endpoint } ;
1522use tower:: { Service , service_fn} ;
1623
24+ #[ cfg( unix) ]
25+ use tokio:: net:: UnixStream ;
26+
1727/// Options for HTTP CONNECT proxy.
1828#[ derive( Clone , Debug ) ]
1929pub struct HttpConnectProxyOptions {
20- /// The host:port to proxy through.
30+ /// The host:port to proxy through for TCP, or unix:/path/to/unix.sock for
31+ /// Unix socket (which means it must start with "unix:/").
2132 pub target_addr : String ,
2233 /// Optional HTTP basic auth for the proxy as user/pass tuple.
2334 pub basic_auth : Option < ( String , String ) > ,
@@ -72,7 +83,7 @@ impl HttpConnectProxyOptions {
7283struct OverrideAddrConnector ( String ) ;
7384
7485impl Service < hyper:: Uri > for OverrideAddrConnector {
75- type Response = TokioIo < TcpStream > ;
86+ type Response = TokioIo < ProxyStream > ;
7687
7788 type Error = anyhow:: Error ;
7889
@@ -84,7 +95,115 @@ impl Service<hyper::Uri> for OverrideAddrConnector {
8495
8596 fn call ( & mut self , _uri : hyper:: Uri ) -> Self :: Future {
8697 let target_addr = self . 0 . clone ( ) ;
87- let fut = async move { Ok ( TokioIo :: new ( TcpStream :: connect ( target_addr) . await ?) ) } ;
98+ let fut = async move {
99+ Ok ( TokioIo :: new (
100+ ProxyStream :: connect ( target_addr. as_str ( ) ) . await ?,
101+ ) )
102+ } ;
88103 Box :: pin ( fut)
89104 }
90105}
106+
107+ /// Visible only for tests
108+ #[ doc( hidden) ]
109+ pub enum ProxyStream {
110+ Tcp ( TcpStream ) ,
111+ #[ cfg( unix) ]
112+ Unix ( UnixStream ) ,
113+ }
114+
115+ impl ProxyStream {
116+ async fn connect ( target_addr : & str ) -> anyhow:: Result < Self > {
117+ if target_addr. starts_with ( "unix:/" ) {
118+ #[ cfg( unix) ]
119+ {
120+ Ok ( ProxyStream :: Unix (
121+ UnixStream :: connect ( & target_addr[ 5 ..] ) . await ?,
122+ ) )
123+ }
124+ #[ cfg( not( unix) ) ]
125+ {
126+ Err ( anyhow:: anyhow!(
127+ "Unix sockets are not supported on this platform"
128+ ) )
129+ }
130+ } else {
131+ Ok ( ProxyStream :: Tcp ( TcpStream :: connect ( target_addr) . await ?) )
132+ }
133+ }
134+ }
135+
136+ impl AsyncRead for ProxyStream {
137+ fn poll_read (
138+ self : Pin < & mut Self > ,
139+ cx : & mut Context < ' _ > ,
140+ buf : & mut ReadBuf < ' _ > ,
141+ ) -> Poll < io:: Result < ( ) > > {
142+ match self . get_mut ( ) {
143+ ProxyStream :: Tcp ( s) => Pin :: new ( s) . poll_read ( cx, buf) ,
144+ #[ cfg( unix) ]
145+ ProxyStream :: Unix ( s) => Pin :: new ( s) . poll_read ( cx, buf) ,
146+ }
147+ }
148+ }
149+
150+ impl AsyncWrite for ProxyStream {
151+ fn poll_write (
152+ self : Pin < & mut Self > ,
153+ cx : & mut Context < ' _ > ,
154+ buf : & [ u8 ] ,
155+ ) -> Poll < io:: Result < usize > > {
156+ match self . get_mut ( ) {
157+ ProxyStream :: Tcp ( s) => Pin :: new ( s) . poll_write ( cx, buf) ,
158+ #[ cfg( unix) ]
159+ ProxyStream :: Unix ( s) => Pin :: new ( s) . poll_write ( cx, buf) ,
160+ }
161+ }
162+
163+ fn poll_write_vectored (
164+ self : Pin < & mut Self > ,
165+ cx : & mut Context < ' _ > ,
166+ bufs : & [ io:: IoSlice < ' _ > ] ,
167+ ) -> Poll < io:: Result < usize > > {
168+ match self . get_mut ( ) {
169+ ProxyStream :: Tcp ( s) => Pin :: new ( s) . poll_write_vectored ( cx, bufs) ,
170+ #[ cfg( unix) ]
171+ ProxyStream :: Unix ( s) => Pin :: new ( s) . poll_write_vectored ( cx, bufs) ,
172+ }
173+ }
174+
175+ fn is_write_vectored ( & self ) -> bool {
176+ match self {
177+ ProxyStream :: Tcp ( s) => s. is_write_vectored ( ) ,
178+ #[ cfg( unix) ]
179+ ProxyStream :: Unix ( s) => s. is_write_vectored ( ) ,
180+ }
181+ }
182+
183+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
184+ match self . get_mut ( ) {
185+ ProxyStream :: Tcp ( s) => Pin :: new ( s) . poll_flush ( cx) ,
186+ #[ cfg( unix) ]
187+ ProxyStream :: Unix ( s) => Pin :: new ( s) . poll_flush ( cx) ,
188+ }
189+ }
190+
191+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
192+ match self . get_mut ( ) {
193+ ProxyStream :: Tcp ( s) => Pin :: new ( s) . poll_shutdown ( cx) ,
194+ #[ cfg( unix) ]
195+ ProxyStream :: Unix ( s) => Pin :: new ( s) . poll_shutdown ( cx) ,
196+ }
197+ }
198+ }
199+
200+ impl Connection for ProxyStream {
201+ fn connected ( & self ) -> Connected {
202+ match self {
203+ ProxyStream :: Tcp ( s) => s. connected ( ) ,
204+ // There is no special connected metadata for Unix sockets
205+ #[ cfg( unix) ]
206+ ProxyStream :: Unix ( _) => Connected :: new ( ) ,
207+ }
208+ }
209+ }
0 commit comments