Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/user-guide/src/environment-variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@

- `RUSTUP_TERM_WIDTH` (default: none). Allows to override the terminal width for progress bars.

- `RUSTUP_DOWNLOAD_TIMEOUT` *unstable* (default: 180). Allows to override the default
timeout (in seconds) for downloading components.

[directive syntax]: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives
[dc]: https://docs.docker.com/storage/storagedriver/overlayfs-driver/#modifying-files-or-directories
[override]: overrides.md
Expand Down
74 changes: 46 additions & 28 deletions src/download/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
//! Easy file downloading

use std::fs::remove_file;
use std::num::NonZeroU64;
use std::path::Path;
use std::str::FromStr;
use std::time::Duration;

use anyhow::Context;
#[cfg(any(
Expand Down Expand Up @@ -194,6 +197,13 @@ async fn download_file_(
_ => Backend::Curl,
};

let timeout = Duration::from_secs(match process.var("RUSTUP_DOWNLOAD_TIMEOUT") {
Ok(s) => NonZeroU64::from_str(&s).context(
"invalid value in RUSTUP_DOWNLOAD_TIMEOUT -- must be a natural number greater than zero",
)?.get(),
Err(_) => 180,
});

notify_handler(match backend {
#[cfg(feature = "curl-backend")]
Backend::Curl => Notification::UsingCurl,
Expand All @@ -202,7 +212,7 @@ async fn download_file_(
});

let res = backend
.download_to_path(url, path, resume_from_partial, Some(callback))
.download_to_path(url, path, resume_from_partial, Some(callback), timeout)
.await;

notify_handler(Notification::DownloadFinished);
Expand Down Expand Up @@ -241,9 +251,10 @@ impl Backend {
path: &Path,
resume_from_partial: bool,
callback: Option<DownloadCallback<'_>>,
timeout: Duration,
) -> anyhow::Result<()> {
let Err(err) = self
.download_impl(url, path, resume_from_partial, callback)
.download_impl(url, path, resume_from_partial, callback, timeout)
.await
else {
return Ok(());
Expand All @@ -265,6 +276,7 @@ impl Backend {
path: &Path,
resume_from_partial: bool,
callback: Option<DownloadCallback<'_>>,
timeout: Duration,
) -> anyhow::Result<()> {
use std::cell::RefCell;
use std::fs::OpenOptions;
Expand Down Expand Up @@ -324,7 +336,7 @@ impl Backend {
let file = RefCell::new(file);

// TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange.
self.download(url, resume_from, &|event| {
self.download(url, resume_from, timeout, &|event| {
if let Event::DownloadDataReceived(data) = event {
file.borrow_mut()
.write_all(data)
Expand Down Expand Up @@ -356,13 +368,14 @@ impl Backend {
self,
url: &Url,
resume_from: u64,
timeout: Duration,
callback: DownloadCallback<'_>,
) -> anyhow::Result<()> {
match self {
#[cfg(feature = "curl-backend")]
Self::Curl => curl::download(url, resume_from, callback),
Self::Curl => curl::download(url, resume_from, callback, timeout),
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
Self::Reqwest(tls) => tls.download(url, resume_from, callback).await,
Self::Reqwest(tls) => tls.download(url, resume_from, callback, timeout).await,
}
}
}
Expand All @@ -383,12 +396,13 @@ impl TlsBackend {
url: &Url,
resume_from: u64,
callback: DownloadCallback<'_>,
timeout: Duration,
) -> anyhow::Result<()> {
let client = match self {
#[cfg(feature = "reqwest-rustls-tls")]
Self::Rustls => reqwest_be::rustls_client()?,
Self::Rustls => reqwest_be::rustls_client(timeout)?,
#[cfg(feature = "reqwest-native-tls")]
Self::NativeTls => &reqwest_be::CLIENT_NATIVE_TLS,
Self::NativeTls => reqwest_be::native_tls_client(timeout)?,
};

reqwest_be::download(url, resume_from, callback, client).await
Expand Down Expand Up @@ -424,6 +438,7 @@ mod curl {
url: &Url,
resume_from: u64,
callback: &dyn Fn(Event<'_>) -> Result<()>,
timeout: Duration,
) -> Result<()> {
// Fetch either a cached libcurl handle (which will preserve open
// connections) or create a new one if it isn't listed.
Expand All @@ -446,8 +461,8 @@ mod curl {
let _ = handle.resume_from(0);
}

// Take at most 30s to connect
handle.connect_timeout(Duration::new(30, 0))?;
// Take at most 3m to connect if the `RUSTUP_DOWNLOAD_TIMEOUT` env var is not set.
handle.connect_timeout(timeout)?;

{
let cberr = RefCell::new(None);
Expand Down Expand Up @@ -526,9 +541,7 @@ mod curl {
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
mod reqwest_be {
use std::io;
#[cfg(feature = "reqwest-native-tls")]
use std::sync::LazyLock;
#[cfg(feature = "reqwest-rustls-tls")]
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
use std::sync::{Arc, OnceLock};
use std::time::Duration;

Expand Down Expand Up @@ -586,11 +599,11 @@ mod reqwest_be {
.pool_max_idle_per_host(0)
.gzip(false)
.proxy(Proxy::custom(env_proxy))
.read_timeout(Duration::from_secs(30))
}

#[cfg(feature = "reqwest-rustls-tls")]
pub(super) fn rustls_client() -> Result<&'static Client, DownloadError> {
pub(super) fn rustls_client(timeout: Duration) -> Result<&'static Client, DownloadError> {
// If the client is already initialized, the passed timeout is ignored.
if let Some(client) = CLIENT_RUSTLS_TLS.get() {
return Ok(client);
}
Expand All @@ -607,6 +620,7 @@ mod reqwest_be {
tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

let client = client_generic()
.read_timeout(timeout)
.use_preconfigured_tls(tls_config)
.user_agent(super::REQWEST_RUSTLS_TLS_USER_AGENT)
.build()
Expand All @@ -622,21 +636,25 @@ mod reqwest_be {
static CLIENT_RUSTLS_TLS: OnceLock<Client> = OnceLock::new();

#[cfg(feature = "reqwest-native-tls")]
pub(super) static CLIENT_NATIVE_TLS: LazyLock<Client> = LazyLock::new(|| {
let catcher = || {
client_generic()
.user_agent(super::REQWEST_DEFAULT_TLS_USER_AGENT)
.build()
};
pub(super) fn native_tls_client(timeout: Duration) -> Result<&'static Client, DownloadError> {
// If the client is already initialized, the passed timeout is ignored.
if let Some(client) = CLIENT_NATIVE_TLS.get() {
return Ok(client);
}

// woah, an unwrap?!
// It's OK. This is the same as what is happening in curl.
//
// The curl::Easy::new() internally assert!s that the initialized
// Easy is not null. Inside reqwest, the errors here would be from
// the TLS library returning a null pointer as well.
catcher().unwrap()
});
let client = client_generic()
.read_timeout(timeout)
.user_agent(super::REQWEST_DEFAULT_TLS_USER_AGENT)
.build()
.map_err(DownloadError::Reqwest)?;

let _ = CLIENT_NATIVE_TLS.set(client);

Ok(CLIENT_NATIVE_TLS.get().unwrap())
}

#[cfg(feature = "reqwest-native-tls")]
static CLIENT_NATIVE_TLS: OnceLock<Client> = OnceLock::new();

fn env_proxy(url: &Url) -> Option<Url> {
env_proxy::for_url(url).to_url()
Expand Down
19 changes: 17 additions & 2 deletions src/download/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use tempfile::TempDir;
mod curl {
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;

use url::Url;

Expand All @@ -36,7 +37,13 @@ mod curl {

let from_url = Url::from_file_path(&from_path).unwrap();
Backend::Curl
.download_to_path(&from_url, &target_path, true, None)
.download_to_path(
&from_url,
&target_path,
true,
None,
Duration::from_secs(180),
)
.await
.expect("Test download failed");

Expand Down Expand Up @@ -83,6 +90,7 @@ mod curl {

Ok(())
}),
Duration::from_secs(180),
)
.await
.expect("Test download failed");
Expand Down Expand Up @@ -185,7 +193,13 @@ mod reqwest {

let from_url = Url::from_file_path(&from_path).unwrap();
Backend::Reqwest(TlsBackend::NativeTls)
.download_to_path(&from_url, &target_path, true, None)
.download_to_path(
&from_url,
&target_path,
true,
None,
Duration::from_secs(180),
)
.await
.expect("Test download failed");

Expand Down Expand Up @@ -232,6 +246,7 @@ mod reqwest {

Ok(())
}),
Duration::from_secs(180),
)
.await
.expect("Test download failed");
Expand Down
Loading