1
1
use itertools:: Itertools ;
2
- use tokio:: net:: lookup_host;
2
+ use thiserror:: Error ;
3
+ use tokio:: net:: { lookup_host, ToSocketAddrs } ;
3
4
use tracing:: warn;
4
5
use uuid:: Uuid ;
5
6
@@ -13,6 +14,7 @@ use crate::transport::errors::{ConnectionPoolError, QueryError};
13
14
use std:: fmt:: Display ;
14
15
use std:: io;
15
16
use std:: net:: IpAddr ;
17
+ use std:: time:: Duration ;
16
18
use std:: {
17
19
hash:: { Hash , Hasher } ,
18
20
net:: SocketAddr ,
@@ -267,27 +269,53 @@ pub(crate) struct ResolvedContactPoint {
267
269
pub ( crate ) datacenter : Option < String > ,
268
270
}
269
271
272
+ #[ derive( Error , Debug ) ]
273
+ pub ( crate ) enum DnsLookupError {
274
+ #[ error( "Failed to perform DNS lookup within {0}ms" ) ]
275
+ Timeout ( u128 ) ,
276
+ #[ error( "Empty address list returned by DNS for {0}" ) ]
277
+ EmptyAddressListForHost ( String ) ,
278
+ #[ error( transparent) ]
279
+ IoError ( #[ from] io:: Error ) ,
280
+ }
281
+
282
+ /// Performs a DNS lookup with provided optional timeout.
283
+ async fn lookup_host_with_timeout < T : ToSocketAddrs > (
284
+ host : T ,
285
+ hostname_resolution_timeout : Option < Duration > ,
286
+ ) -> Result < impl Iterator < Item = SocketAddr > , DnsLookupError > {
287
+ if let Some ( timeout) = hostname_resolution_timeout {
288
+ match tokio:: time:: timeout ( timeout, lookup_host ( host) ) . await {
289
+ Ok ( res) => res. map_err ( Into :: into) ,
290
+ // Elapsed error from tokio library does not provide any context.
291
+ Err ( _) => Err ( DnsLookupError :: Timeout ( timeout. as_millis ( ) ) ) ,
292
+ }
293
+ } else {
294
+ lookup_host ( host) . await . map_err ( Into :: into)
295
+ }
296
+ }
297
+
270
298
// Resolve the given hostname using a DNS lookup if necessary.
271
299
// The resolution may return multiple IPs and the function returns one of them.
272
300
// It prefers to return IPv4s first, and only if there are none, IPv6s.
273
- pub ( crate ) async fn resolve_hostname ( hostname : & str ) -> Result < SocketAddr , io:: Error > {
274
- let addrs = match lookup_host ( hostname) . await {
301
+ pub ( crate ) async fn resolve_hostname (
302
+ hostname : & str ,
303
+ hostname_resolution_timeout : Option < Duration > ,
304
+ ) -> Result < SocketAddr , DnsLookupError > {
305
+ let addrs = match lookup_host_with_timeout ( hostname, hostname_resolution_timeout) . await {
275
306
Ok ( addrs) => itertools:: Either :: Left ( addrs) ,
276
307
// Use a default port in case of error, but propagate the original error on failure
277
308
Err ( e) => {
278
- let addrs = lookup_host ( ( hostname, 9042 ) ) . await . or ( Err ( e) ) ?;
309
+ let addrs = lookup_host_with_timeout ( ( hostname, 9042 ) , hostname_resolution_timeout)
310
+ . await
311
+ . or ( Err ( e) ) ?;
279
312
itertools:: Either :: Right ( addrs)
280
313
}
281
314
} ;
282
315
283
316
addrs
284
317
. find_or_last ( |addr| matches ! ( addr, SocketAddr :: V4 ( _) ) )
285
- . ok_or_else ( || {
286
- io:: Error :: new (
287
- io:: ErrorKind :: Other ,
288
- format ! ( "Empty address list returned by DNS for {}" , hostname) ,
289
- )
290
- } )
318
+ . ok_or_else ( || DnsLookupError :: EmptyAddressListForHost ( hostname. to_owned ( ) ) )
291
319
}
292
320
293
321
/// Transforms the given [`InternalKnownNode`]s into [`ContactPoint`]s.
@@ -296,6 +324,7 @@ pub(crate) async fn resolve_hostname(hostname: &str) -> Result<SocketAddr, io::E
296
324
/// In case of a plain IP address, parses it and uses straight.
297
325
pub ( crate ) async fn resolve_contact_points (
298
326
known_nodes : & [ InternalKnownNode ] ,
327
+ hostname_resolution_timeout : Option < Duration > ,
299
328
) -> ( Vec < ResolvedContactPoint > , Vec < String > ) {
300
329
// Find IP addresses of all known nodes passed in the config
301
330
let mut initial_peers: Vec < ResolvedContactPoint > = Vec :: with_capacity ( known_nodes. len ( ) ) ;
@@ -323,7 +352,7 @@ pub(crate) async fn resolve_contact_points(
323
352
let resolve_futures = to_resolve
324
353
. into_iter ( )
325
354
. map ( |( hostname, datacenter) | async move {
326
- match resolve_hostname ( hostname) . await {
355
+ match resolve_hostname ( hostname, hostname_resolution_timeout ) . await {
327
356
Ok ( address) => Some ( ResolvedContactPoint {
328
357
address,
329
358
datacenter,
0 commit comments