1
1
//! Easy file downloading
2
2
3
3
use std:: fs:: remove_file;
4
+ use std:: num:: NonZeroU64 ;
4
5
use std:: path:: Path ;
6
+ use std:: str:: FromStr ;
7
+ use std:: time:: Duration ;
5
8
6
9
use anyhow:: Context ;
7
10
#[ cfg( any(
@@ -194,6 +197,13 @@ async fn download_file_(
194
197
_ => Backend :: Curl ,
195
198
} ;
196
199
200
+ let timeout = Duration :: from_secs ( match process. var ( "RUSTUP_DOWNLOAD_TIMEOUT" ) {
201
+ Ok ( s) => NonZeroU64 :: from_str ( & s) . context (
202
+ "invalid value in RUSTUP_DOWNLOAD_TIMEOUT -- must be a natural number greater than zero" ,
203
+ ) ?. get ( ) ,
204
+ Err ( _) => 180 ,
205
+ } ) ;
206
+
197
207
notify_handler ( match backend {
198
208
#[ cfg( feature = "curl-backend" ) ]
199
209
Backend :: Curl => Notification :: UsingCurl ,
@@ -202,7 +212,7 @@ async fn download_file_(
202
212
} ) ;
203
213
204
214
let res = backend
205
- . download_to_path ( url, path, resume_from_partial, Some ( callback) )
215
+ . download_to_path ( url, path, resume_from_partial, Some ( callback) , timeout )
206
216
. await ;
207
217
208
218
notify_handler ( Notification :: DownloadFinished ) ;
@@ -241,9 +251,10 @@ impl Backend {
241
251
path : & Path ,
242
252
resume_from_partial : bool ,
243
253
callback : Option < DownloadCallback < ' _ > > ,
254
+ timeout : Duration ,
244
255
) -> anyhow:: Result < ( ) > {
245
256
let Err ( err) = self
246
- . download_impl ( url, path, resume_from_partial, callback)
257
+ . download_impl ( url, path, resume_from_partial, callback, timeout )
247
258
. await
248
259
else {
249
260
return Ok ( ( ) ) ;
@@ -265,6 +276,7 @@ impl Backend {
265
276
path : & Path ,
266
277
resume_from_partial : bool ,
267
278
callback : Option < DownloadCallback < ' _ > > ,
279
+ timeout : Duration ,
268
280
) -> anyhow:: Result < ( ) > {
269
281
use std:: cell:: RefCell ;
270
282
use std:: fs:: OpenOptions ;
@@ -324,7 +336,7 @@ impl Backend {
324
336
let file = RefCell :: new ( file) ;
325
337
326
338
// TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange.
327
- self . download ( url, resume_from, & |event| {
339
+ self . download ( url, resume_from, timeout , & |event| {
328
340
if let Event :: DownloadDataReceived ( data) = event {
329
341
file. borrow_mut ( )
330
342
. write_all ( data)
@@ -356,13 +368,14 @@ impl Backend {
356
368
self ,
357
369
url : & Url ,
358
370
resume_from : u64 ,
371
+ timeout : Duration ,
359
372
callback : DownloadCallback < ' _ > ,
360
373
) -> anyhow:: Result < ( ) > {
361
374
match self {
362
375
#[ cfg( feature = "curl-backend" ) ]
363
- Self :: Curl => curl:: download ( url, resume_from, callback) ,
376
+ Self :: Curl => curl:: download ( url, resume_from, callback, timeout ) ,
364
377
#[ cfg( any( feature = "reqwest-rustls-tls" , feature = "reqwest-native-tls" ) ) ]
365
- Self :: Reqwest ( tls) => tls. download ( url, resume_from, callback) . await ,
378
+ Self :: Reqwest ( tls) => tls. download ( url, resume_from, callback, timeout ) . await ,
366
379
}
367
380
}
368
381
}
@@ -383,12 +396,13 @@ impl TlsBackend {
383
396
url : & Url ,
384
397
resume_from : u64 ,
385
398
callback : DownloadCallback < ' _ > ,
399
+ timeout : Duration ,
386
400
) -> anyhow:: Result < ( ) > {
387
401
let client = match self {
388
402
#[ cfg( feature = "reqwest-rustls-tls" ) ]
389
- Self :: Rustls => reqwest_be:: rustls_client ( ) ?,
403
+ Self :: Rustls => reqwest_be:: rustls_client ( timeout ) ?,
390
404
#[ cfg( feature = "reqwest-native-tls" ) ]
391
- Self :: NativeTls => & reqwest_be:: CLIENT_NATIVE_TLS ,
405
+ Self :: NativeTls => reqwest_be:: native_tls_client ( timeout ) ? ,
392
406
} ;
393
407
394
408
reqwest_be:: download ( url, resume_from, callback, client) . await
@@ -424,6 +438,7 @@ mod curl {
424
438
url : & Url ,
425
439
resume_from : u64 ,
426
440
callback : & dyn Fn ( Event < ' _ > ) -> Result < ( ) > ,
441
+ timeout : Duration ,
427
442
) -> Result < ( ) > {
428
443
// Fetch either a cached libcurl handle (which will preserve open
429
444
// connections) or create a new one if it isn't listed.
@@ -446,8 +461,8 @@ mod curl {
446
461
let _ = handle. resume_from ( 0 ) ;
447
462
}
448
463
449
- // Take at most 30s to connect
450
- handle. connect_timeout ( Duration :: new ( 30 , 0 ) ) ?;
464
+ // Take at most 3m to connect if the `RUSTUP_DOWNLOAD_TIMEOUT` env var is not set.
465
+ handle. connect_timeout ( timeout ) ?;
451
466
452
467
{
453
468
let cberr = RefCell :: new ( None ) ;
@@ -526,9 +541,7 @@ mod curl {
526
541
#[ cfg( any( feature = "reqwest-rustls-tls" , feature = "reqwest-native-tls" ) ) ]
527
542
mod reqwest_be {
528
543
use std:: io;
529
- #[ cfg( feature = "reqwest-native-tls" ) ]
530
- use std:: sync:: LazyLock ;
531
- #[ cfg( feature = "reqwest-rustls-tls" ) ]
544
+ #[ cfg( any( feature = "reqwest-rustls-tls" , feature = "reqwest-native-tls" ) ) ]
532
545
use std:: sync:: { Arc , OnceLock } ;
533
546
use std:: time:: Duration ;
534
547
@@ -586,11 +599,10 @@ mod reqwest_be {
586
599
. pool_max_idle_per_host ( 0 )
587
600
. gzip ( false )
588
601
. proxy ( Proxy :: custom ( env_proxy) )
589
- . read_timeout ( Duration :: from_secs ( 30 ) )
590
602
}
591
603
592
604
#[ cfg( feature = "reqwest-rustls-tls" ) ]
593
- pub ( super ) fn rustls_client ( ) -> Result < & ' static Client , DownloadError > {
605
+ pub ( super ) fn rustls_client ( timeout : Duration ) -> Result < & ' static Client , DownloadError > {
594
606
if let Some ( client) = CLIENT_RUSTLS_TLS . get ( ) {
595
607
return Ok ( client) ;
596
608
}
@@ -607,6 +619,7 @@ mod reqwest_be {
607
619
tls_config. alpn_protocols = vec ! [ b"h2" . to_vec( ) , b"http/1.1" . to_vec( ) ] ;
608
620
609
621
let client = client_generic ( )
622
+ . read_timeout ( timeout)
610
623
. use_preconfigured_tls ( tls_config)
611
624
. user_agent ( super :: REQWEST_RUSTLS_TLS_USER_AGENT )
612
625
. build ( )
@@ -622,21 +635,24 @@ mod reqwest_be {
622
635
static CLIENT_RUSTLS_TLS : OnceLock < Client > = OnceLock :: new ( ) ;
623
636
624
637
#[ cfg( feature = "reqwest-native-tls" ) ]
625
- pub ( super ) static CLIENT_NATIVE_TLS : LazyLock < Client > = LazyLock :: new ( || {
626
- let catcher = || {
627
- client_generic ( )
628
- . user_agent ( super :: REQWEST_DEFAULT_TLS_USER_AGENT )
629
- . build ( )
630
- } ;
638
+ pub ( super ) fn native_tls_client ( timeout : Duration ) -> Result < & ' static Client , DownloadError > {
639
+ if let Some ( client) = CLIENT_NATIVE_TLS . get ( ) {
640
+ return Ok ( client) ;
641
+ }
631
642
632
- // woah, an unwrap?!
633
- // It's OK. This is the same as what is happening in curl.
634
- //
635
- // The curl::Easy::new() internally assert!s that the initialized
636
- // Easy is not null. Inside reqwest, the errors here would be from
637
- // the TLS library returning a null pointer as well.
638
- catcher ( ) . unwrap ( )
639
- } ) ;
643
+ let client = client_generic ( )
644
+ . read_timeout ( timeout)
645
+ . user_agent ( super :: REQWEST_DEFAULT_TLS_USER_AGENT )
646
+ . build ( )
647
+ . map_err ( DownloadError :: Reqwest ) ?;
648
+
649
+ let _ = CLIENT_NATIVE_TLS . set ( client) ;
650
+
651
+ Ok ( CLIENT_NATIVE_TLS . get ( ) . unwrap ( ) )
652
+ }
653
+
654
+ #[ cfg( feature = "reqwest-native-tls" ) ]
655
+ static CLIENT_NATIVE_TLS : OnceLock < Client > = OnceLock :: new ( ) ;
640
656
641
657
fn env_proxy ( url : & Url ) -> Option < Url > {
642
658
env_proxy:: for_url ( url) . to_url ( )
0 commit comments