Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions doc/user-guide/src/environment-variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@
- `RUSTUP_VERSION` (default: none). Overrides the rustup version (e.g. `1.27.1`)
to be downloaded when executing `rustup-init.sh` or `rustup self update`.

- `RUSTUP_AUTHORIZATION_HEADER` (default: none). The value to an `Authorization` HTTP
header that should be added to all requests made by rustup. This is meant for use when
using an alternate rustup distribution server (through the `RUSTUP_DIST_SERVER`
environment variable) which requires authentication such as basic username:password
credentials or a bearer token.

- `RUSTUP_PROXY_AUTHORIZATION_HEADER` (default: none). This is like the `RUSTUP_AUTHORIZATION_HEADER` except
this will add a `Proxy-Authorization` HTTP header. This is for authenticating to forward
proxies (via the `HTTP_PROXY` or `HTTPS_PROXY`) environment variables.

- `RUSTUP_IO_THREADS` *unstable* (default: reported cpu count, max 8). Sets the
number of threads to perform close IO in. Set to `1` to force
single-threaded IO for troubleshooting, or an arbitrary number to override
Expand Down
157 changes: 130 additions & 27 deletions src/download/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,14 @@ async fn download_file_(
};

let res = backend
.download_to_path(url, path, resume_from_partial, Some(callback), timeout)
.download_to_path(
url,
path,
resume_from_partial,
Some(callback),
timeout,
process,
)
.await;

// The notification should only be sent if the download was successful (i.e. didn't timeout)
Expand Down Expand Up @@ -253,9 +260,10 @@ impl Backend {
resume_from_partial: bool,
callback: Option<DownloadCallback<'_>>,
timeout: Duration,
process: &Process,
) -> anyhow::Result<()> {
let Err(err) = self
.download_impl(url, path, resume_from_partial, callback, timeout)
.download_impl(url, path, resume_from_partial, callback, timeout, process)
.await
else {
return Ok(());
Expand All @@ -278,6 +286,7 @@ impl Backend {
resume_from_partial: bool,
callback: Option<DownloadCallback<'_>>,
timeout: Duration,
process: &Process,
) -> anyhow::Result<()> {
use std::cell::RefCell;
use std::fs::OpenOptions;
Expand Down Expand Up @@ -337,17 +346,23 @@ impl Backend {
let file = RefCell::new(file);

// TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange.
self.download(url, resume_from, timeout, &|event| {
if let Event::DownloadDataReceived(data) = event {
file.borrow_mut()
.write_all(data)
.context("unable to write download to disk")?;
}
match callback {
Some(cb) => cb(event),
None => Ok(()),
}
})
self.download(
url,
resume_from,
timeout,
&|event| {
if let Event::DownloadDataReceived(data) = event {
file.borrow_mut()
.write_all(data)
.context("unable to write download to disk")?;
}
match callback {
Some(cb) => cb(event),
None => Ok(()),
}
},
process,
)
.await?;

file.borrow_mut()
Expand All @@ -371,12 +386,16 @@ impl Backend {
resume_from: u64,
timeout: Duration,
callback: DownloadCallback<'_>,
process: &Process,
) -> anyhow::Result<()> {
match self {
#[cfg(feature = "curl-backend")]
Self::Curl => curl::download(url, resume_from, callback, timeout),
Self::Curl => curl::download(url, resume_from, callback, timeout, process),
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
Self::Reqwest(tls) => tls.download(url, resume_from, callback, timeout).await,
Self::Reqwest(tls) => {
tls.download(url, resume_from, callback, timeout, process)
.await
}
}
}
}
Expand All @@ -398,12 +417,13 @@ impl TlsBackend {
resume_from: u64,
callback: DownloadCallback<'_>,
timeout: Duration,
process: &Process,
) -> anyhow::Result<()> {
let client = match self {
#[cfg(feature = "reqwest-rustls-tls")]
Self::Rustls => reqwest_be::rustls_client(timeout)?,
Self::Rustls => reqwest_be::rustls_client(timeout, process)?,
#[cfg(feature = "reqwest-native-tls")]
Self::NativeTls => reqwest_be::native_tls_client(timeout)?,
Self::NativeTls => reqwest_be::native_tls_client(timeout, process)?,
};

reqwest_be::download(url, resume_from, callback, client).await
Expand All @@ -430,16 +450,42 @@ mod curl {
use std::time::Duration;

use anyhow::{Context, Result};
use curl::easy::Easy;
use curl::easy::{Easy, List};
use tracing::debug;
use url::Url;

use super::{DownloadError, Event};
use super::{DownloadError, Event, Process};

macro_rules! add_header_for_curl_easy_handle {
($handle:ident, $process:ident, $env_var:literal, $header_name:literal) => {
if let Some(rustup_header_value) = $process.var_opt($env_var).map_err(|error| {
anyhow::anyhow!(
"Internal error getting `{}` environment variable: {}",
$env_var,
anyhow::format_err!(error)
)
})? {
let mut list = List::new();
list.append(format!("{}: {}", $header_name, rustup_header_value).as_str())
.map_err(|_| {
// The error could contain sensitive data so give a generic error instead.
anyhow::anyhow!("Failed to add `{}` HTTP header.", $header_name)
})?;
$handle.http_headers(list).map_err(|_| {
// The error could contain sensitive data so give a generic error instead.
anyhow::anyhow!("Failed to add headers to curl easy handle.")
})?;
debug!("Added `{}` header.", $header_name);
}
};
}

pub(super) fn download(
url: &Url,
resume_from: u64,
callback: &dyn Fn(Event<'_>) -> Result<()>,
timeout: Duration,
process: &Process,
) -> Result<()> {
// Fetch either a cached libcurl handle (which will preserve open
// connections) or create a new one if it isn't listed.
Expand All @@ -453,6 +499,18 @@ mod curl {
handle.url(url.as_ref())?;
handle.follow_location(true)?;
handle.useragent(super::CURL_USER_AGENT)?;
add_header_for_curl_easy_handle!(
handle,
process,
"RUSTUP_AUTHORIZATION_HEADER",
"Authorization"
);
add_header_for_curl_easy_handle!(
handle,
process,
"RUSTUP_PROXY_AUTHORIZATION_HEADER",
"Proxy-Authorization"
);

if resume_from > 0 {
handle.resume_from(resume_from)?;
Expand Down Expand Up @@ -557,7 +615,33 @@ mod reqwest_be {
use tokio_stream::StreamExt;
use url::Url;

use super::{DownloadError, Event};
use super::{DownloadError, Event, Process, debug};

macro_rules! add_header_for_client_builder {
($client_builder:ident, $process:ident, $env_var:literal, $header_name:path) => {
if let Some(rustup_header_value) = $process.var_opt($env_var).map_err(|_| {
// The error could contain sensitive data so give a generic error instead.
DownloadError::Message(format!(
"Internal error getting `{}` environment variable",
$env_var
))
})? {
let mut headers = header::HeaderMap::new();
let mut auth_value =
header::HeaderValue::from_str(&rustup_header_value).map_err(|_| {
// The error could contain sensitive data so give a generic error instead.
DownloadError::Message(format!(
"The `{}` environment variable set to an invalid HTTP header value.",
$env_var
))
})?;
auth_value.set_sensitive(true);
headers.insert($header_name, auth_value);
$client_builder = $client_builder.default_headers(headers);
debug!("Added `{}` header.", $header_name);
}
};
}

pub(super) async fn download(
url: &Url,
Expand Down Expand Up @@ -592,18 +676,34 @@ mod reqwest_be {
Ok(())
}

fn client_generic() -> ClientBuilder {
Client::builder()
fn client_generic(process: &Process) -> Result<ClientBuilder, DownloadError> {
let mut client_builder = Client::builder()
// HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying
// `hyper` library that causes the `reqwest` client to hang in some cases.
// See <https://github.com/hyperium/hyper/issues/2312> for more details.
.pool_max_idle_per_host(0)
.gzip(false)
.proxy(Proxy::custom(env_proxy))
.proxy(Proxy::custom(env_proxy));
add_header_for_client_builder!(
client_builder,
process,
"RUSTUP_AUTHORIZATION_HEADER",
header::AUTHORIZATION
);
add_header_for_client_builder!(
client_builder,
process,
"RUSTUP_PROXY_AUTHORIZATION_HEADER",
header::PROXY_AUTHORIZATION
);
Ok(client_builder)
}

#[cfg(feature = "reqwest-rustls-tls")]
pub(super) fn rustls_client(timeout: Duration) -> Result<&'static Client, DownloadError> {
pub(super) fn rustls_client(
timeout: Duration,
process: &Process,
) -> Result<&'static Client, DownloadError> {
// If the client is already initialized, the passed timeout is ignored.
if let Some(client) = CLIENT_RUSTLS_TLS.get() {
return Ok(client);
Expand All @@ -627,7 +727,7 @@ mod reqwest_be {
.with_no_client_auth();
tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

let client = client_generic()
let client = client_generic(process)?
.read_timeout(timeout)
.use_preconfigured_tls(tls_config)
.user_agent(super::REQWEST_RUSTLS_TLS_USER_AGENT)
Expand All @@ -644,13 +744,16 @@ mod reqwest_be {
static CLIENT_RUSTLS_TLS: OnceLock<Client> = OnceLock::new();

#[cfg(feature = "reqwest-native-tls")]
pub(super) fn native_tls_client(timeout: Duration) -> Result<&'static Client, DownloadError> {
pub(super) fn native_tls_client(
timeout: Duration,
process: &Process,
) -> Result<&'static Client, DownloadError> {
// If the client is already initialized, the passed timeout is ignored.
if let Some(client) = CLIENT_NATIVE_TLS.get() {
return Ok(client);
}

let client = client_generic()
let client = client_generic(process)?
.read_timeout(timeout)
.user_agent(super::REQWEST_DEFAULT_TLS_USER_AGENT)
.build()
Expand Down
10 changes: 10 additions & 0 deletions src/download/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ mod curl {
use super::{scrub_env, serve_file, tmp_dir, write_file};
use crate::download::{Backend, Event};

#[cfg(feature = "test")]
use crate::process::TestProcess;

#[tokio::test]
async fn partially_downloaded_file_gets_resumed_from_byte_offset() {
let tmpdir = tmp_dir();
Expand All @@ -43,6 +46,7 @@ mod curl {
true,
None,
Duration::from_secs(180),
&TestProcess::default().process,
)
.await
.expect("Test download failed");
Expand Down Expand Up @@ -91,6 +95,7 @@ mod curl {
Ok(())
}),
Duration::from_secs(180),
&TestProcess::default().process,
)
.await
.expect("Test download failed");
Expand Down Expand Up @@ -120,6 +125,9 @@ mod reqwest {
use super::{scrub_env, serve_file, tmp_dir, write_file};
use crate::download::{Backend, Event, TlsBackend};

#[cfg(feature = "test")]
use crate::process::TestProcess;

// Tests for correctly retrieving the proxy (host, port) tuple from $https_proxy
#[tokio::test]
async fn read_basic_proxy_params() {
Expand Down Expand Up @@ -199,6 +207,7 @@ mod reqwest {
true,
None,
Duration::from_secs(180),
&TestProcess::default().process,
)
.await
.expect("Test download failed");
Expand Down Expand Up @@ -247,6 +256,7 @@ mod reqwest {
Ok(())
}),
Duration::from_secs(180),
&TestProcess::default().process,
)
.await
.expect("Test download failed");
Expand Down
1 change: 1 addition & 0 deletions tests/resources/logs/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.log
Loading