@@ -24,20 +24,18 @@ pub type SharedFutureResult<T> = Shared<BoxFuture<'static, Result<Arc<T>, Arc<an
2424
2525#[ derive( Default ) ]
2626pub struct OutboundNetworkingFactor {
27- disallowed_host_callback : Option < DisallowedHostCallback > ,
27+ disallowed_host_handler : Option < Arc < dyn DisallowedHostHandler > > ,
2828}
2929
30- pub type DisallowedHostCallback = fn ( scheme : & str , authority : & str ) ;
31-
3230impl OutboundNetworkingFactor {
3331 pub fn new ( ) -> Self {
3432 Self :: default ( )
3533 }
3634
37- /// Sets a function to be called when a request is disallowed by an
35+ /// Sets a handler to be called when a request is disallowed by an
3836 /// instance's configured `allowed_outbound_hosts`.
39- pub fn set_disallowed_host_callback ( & mut self , callback : DisallowedHostCallback ) {
40- self . disallowed_host_callback = Some ( callback ) ;
37+ pub fn set_disallowed_host_handler ( & mut self , handler : impl DisallowedHostHandler + ' static ) {
38+ self . disallowed_host_handler = Some ( Arc :: new ( handler ) ) ;
4139 }
4240}
4341
@@ -106,7 +104,7 @@ impl Factor for OutboundNetworkingFactor {
106104 // Update Wasi socket allowed ports
107105 let allowed_hosts = OutboundAllowedHosts {
108106 allowed_hosts_future : allowed_hosts_future. clone ( ) ,
109- disallowed_host_callback : self . disallowed_host_callback ,
107+ disallowed_host_handler : self . disallowed_host_handler . clone ( ) ,
110108 } ;
111109 wasi_builder. outbound_socket_addr_check ( move |addr, addr_use| {
112110 let allowed_hosts = allowed_hosts. clone ( ) ;
@@ -137,7 +135,7 @@ impl Factor for OutboundNetworkingFactor {
137135 Ok ( InstanceBuilder {
138136 allowed_hosts_future,
139137 component_tls_configs,
140- disallowed_host_callback : self . disallowed_host_callback ,
138+ disallowed_host_handler : self . disallowed_host_handler . clone ( ) ,
141139 } )
142140 }
143141}
@@ -150,14 +148,14 @@ pub struct AppState {
150148pub struct InstanceBuilder {
151149 allowed_hosts_future : SharedFutureResult < AllowedHostsConfig > ,
152150 component_tls_configs : ComponentTlsConfigs ,
153- disallowed_host_callback : Option < DisallowedHostCallback > ,
151+ disallowed_host_handler : Option < Arc < dyn DisallowedHostHandler > > ,
154152}
155153
156154impl InstanceBuilder {
157155 pub fn allowed_hosts ( & self ) -> OutboundAllowedHosts {
158156 OutboundAllowedHosts {
159157 allowed_hosts_future : self . allowed_hosts_future . clone ( ) ,
160- disallowed_host_callback : self . disallowed_host_callback ,
158+ disallowed_host_handler : self . disallowed_host_handler . clone ( ) ,
161159 }
162160 }
163161
@@ -178,33 +176,10 @@ impl FactorInstanceBuilder for InstanceBuilder {
178176#[ derive( Clone ) ]
179177pub struct OutboundAllowedHosts {
180178 allowed_hosts_future : SharedFutureResult < AllowedHostsConfig > ,
181- disallowed_host_callback : Option < DisallowedHostCallback > ,
179+ disallowed_host_handler : Option < Arc < dyn DisallowedHostHandler > > ,
182180}
183181
184182impl OutboundAllowedHosts {
185- pub async fn resolve ( & self ) -> anyhow:: Result < Arc < AllowedHostsConfig > > {
186- self . allowed_hosts_future . clone ( ) . await . map_err ( |err| {
187- // TODO: better way to handle this?
188- anyhow:: Error :: msg ( err)
189- } )
190- }
191-
192- /// Checks if the given URL is allowed by this component's
193- /// `allowed_outbound_hosts`.
194- pub async fn allows ( & self , url : & OutboundUrl ) -> anyhow:: Result < bool > {
195- Ok ( self . resolve ( ) . await ?. allows ( url) )
196- }
197-
198- /// Report that an outbound connection has been disallowed by e.g.
199- /// [`OutboundAllowedHosts::allows`] returning `false`.
200- ///
201- /// Calls the [`DisallowedHostCallback`] if set.
202- pub fn report_disallowed_host ( & self , scheme : & str , authority : & str ) {
203- if let Some ( disallowed_host_callback) = self . disallowed_host_callback {
204- disallowed_host_callback ( scheme, authority) ;
205- }
206- }
207-
208183 /// Checks address against allowed hosts
209184 ///
210185 /// Calls the [`DisallowedHostCallback`] if set and URL is disallowed.
@@ -217,11 +192,47 @@ impl OutboundAllowedHosts {
217192 } ;
218193
219194 let allowed_hosts = self . resolve ( ) . await ?;
220-
221195 let is_allowed = allowed_hosts. allows ( & url) ;
222196 if !is_allowed {
223197 self . report_disallowed_host ( url. scheme ( ) , & url. authority ( ) ) ;
224198 }
225199 Ok ( is_allowed)
226200 }
201+
202+ /// Checks if allowed hosts permit relative requests
203+ ///
204+ /// Calls the [`DisallowedHostCallback`] if set and relative requests are
205+ /// disallowed.
206+ pub async fn check_relative_url ( & self , schemes : & [ & str ] ) -> anyhow:: Result < bool > {
207+ let allowed_hosts = self . resolve ( ) . await ?;
208+ let is_allowed = allowed_hosts. allows_relative_url ( schemes) ;
209+ if !is_allowed {
210+ let scheme = schemes. first ( ) . unwrap_or ( & "" ) ;
211+ self . report_disallowed_host ( scheme, "self" ) ;
212+ }
213+ Ok ( is_allowed)
214+ }
215+
216+ async fn resolve ( & self ) -> anyhow:: Result < Arc < AllowedHostsConfig > > {
217+ self . allowed_hosts_future . clone ( ) . await . map_err ( |err| {
218+ tracing:: error!( "Error resolving allowed_outbound_hosts variables: {err}" ) ;
219+ anyhow:: Error :: msg ( err)
220+ } )
221+ }
222+
223+ fn report_disallowed_host ( & self , scheme : & str , authority : & str ) {
224+ if let Some ( handler) = & self . disallowed_host_handler {
225+ handler. handle_disallowed_host ( scheme, authority) ;
226+ }
227+ }
228+ }
229+
230+ pub trait DisallowedHostHandler : Send + Sync {
231+ fn handle_disallowed_host ( & self , scheme : & str , authority : & str ) ;
232+ }
233+
234+ impl < F : Fn ( & str , & str ) + Send + Sync > DisallowedHostHandler for F {
235+ fn handle_disallowed_host ( & self , scheme : & str , authority : & str ) {
236+ self ( scheme, authority) ;
237+ }
227238}
0 commit comments