Skip to content

Commit 306b659

Browse files
feat(downloads): interleave the downloads with their installations
Even though downloads are done concurrently, the installations are done sequentially. This means that, as downloads complete, they are in a queue (an mpsc channel) waiting to be consumed by the future responsible for the (sequential) installations.
1 parent 41ed3a2 commit 306b659

File tree

1 file changed

+108
-59
lines changed

1 file changed

+108
-59
lines changed

src/dist/manifestation.rs

Lines changed: 108 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ mod tests;
77
use std::path::Path;
88

99
use anyhow::{Context, Result, anyhow, bail};
10-
use futures_util::stream::StreamExt;
10+
use futures_util::{future::join, stream::StreamExt};
1111
use std::sync::Arc;
12-
use tokio::sync::Semaphore;
12+
use tokio::sync::{Semaphore, mpsc};
1313
use tracing::info;
1414

1515
use crate::dist::component::{
@@ -153,7 +153,6 @@ impl Manifestation {
153153
let altered = tmp_cx.dist_server != DEFAULT_DIST_SERVER;
154154

155155
// Download component packages and validate hashes
156-
let mut things_to_install: Vec<(Component, CompressionKind, File)> = Vec::new();
157156
let mut things_downloaded: Vec<String> = Vec::new();
158157
let components = update.components_urls_and_hashes(new_manifest)?;
159158
let components_len = components.len();
@@ -170,49 +169,7 @@ impl Manifestation {
170169
.and_then(|s| s.parse().ok())
171170
.unwrap_or(DEFAULT_MAX_RETRIES);
172171

173-
info!("downloading component(s)");
174-
for (component, _, url, _) in components.clone() {
175-
(download_cfg.notify_handler)(Notification::DownloadingComponent(
176-
&component.short_name(new_manifest),
177-
&self.target_triple,
178-
component.target.as_ref(),
179-
&url,
180-
));
181-
}
182-
183-
let semaphore = Arc::new(Semaphore::new(num_channels));
184-
let component_stream =
185-
tokio_stream::iter(components.into_iter()).map(|(component, format, url, hash)| {
186-
let sem = semaphore.clone();
187-
async move {
188-
let _permit = sem.acquire().await.unwrap();
189-
self.download_component(
190-
component,
191-
format,
192-
url,
193-
hash,
194-
altered,
195-
tmp_cx,
196-
download_cfg,
197-
max_retries,
198-
new_manifest,
199-
)
200-
.await
201-
}
202-
});
203-
if num_channels > 0 {
204-
let results = component_stream
205-
.buffered(components_len)
206-
.collect::<Vec<_>>()
207-
.await;
208-
for result in results {
209-
let (component, format, downloaded_file, hash) = result?;
210-
things_downloaded.push(hash);
211-
things_to_install.push((component, format, downloaded_file));
212-
}
213-
}
214-
215-
// Begin transaction
172+
// Begin transaction before the downloads, as installations are interleaved with those
216173
let mut tx = Transaction::new(
217174
prefix.clone(),
218175
tmp_cx,
@@ -224,6 +181,16 @@ impl Manifestation {
224181
// to uninstall it first.
225182
tx = self.maybe_handle_v2_upgrade(&config, tx, download_cfg.process)?;
226183

184+
info!("downloading component(s)");
185+
for (component, _, url, _) in components.clone() {
186+
(download_cfg.notify_handler)(Notification::DownloadingComponent(
187+
&component.short_name(new_manifest),
188+
&self.target_triple,
189+
component.target.as_ref(),
190+
&url,
191+
));
192+
}
193+
227194
// Uninstall components
228195
for component in &update.components_to_uninstall {
229196
let notification = if implicit_modify {
@@ -246,17 +213,94 @@ impl Manifestation {
246213
)?;
247214
}
248215

249-
// Install components
250-
for (component, format, installer_file) in things_to_install {
251-
tx = self.install_component(
252-
component,
253-
format,
254-
installer_file,
255-
tmp_cx,
256-
download_cfg,
257-
new_manifest,
258-
tx,
259-
)?;
216+
if num_channels > 0 {
217+
// Create a channel to communicate whenever a download is done and the component can be installed
218+
// The `mpsc` channel was used as we need to send many messages from one producer (download's thread) to one consumer (install's thread)
219+
// This is recommended in the official docs: https://docs.rs/tokio/latest/tokio/sync/index.html#mpsc-channel
220+
let total_components = components.len();
221+
let (download_tx, mut download_rx) =
222+
mpsc::channel::<Result<(Component, CompressionKind, File)>>(total_components);
223+
224+
let semaphore = Arc::new(Semaphore::new(num_channels));
225+
let component_stream =
226+
tokio_stream::iter(components.into_iter()).map(|(component, format, url, hash)| {
227+
let sem = semaphore.clone();
228+
let download_tx_cloned = download_tx.clone();
229+
async move {
230+
let _permit = sem.acquire().await.unwrap();
231+
self.download_component(
232+
component,
233+
format,
234+
url,
235+
hash,
236+
altered,
237+
tmp_cx,
238+
download_cfg,
239+
max_retries,
240+
new_manifest,
241+
download_tx_cloned,
242+
)
243+
.await
244+
}
245+
});
246+
247+
let mut stream = component_stream.buffered(num_channels);
248+
let (download_results, install_result) = join(
249+
async {
250+
let mut hashes = Vec::new();
251+
while let Some(result) = stream.next().await {
252+
match result {
253+
Ok(hash) => {
254+
hashes.push(hash);
255+
}
256+
Err(e) => {
257+
let _ = download_tx.send(Err(e)).await;
258+
}
259+
}
260+
}
261+
hashes
262+
},
263+
async {
264+
let mut current_tx = tx;
265+
let mut counter = 0;
266+
loop {
267+
if counter >= total_components {
268+
break;
269+
}
270+
if let Some(message) = download_rx.recv().await {
271+
match message {
272+
Ok((component, format, installer_file)) => {
273+
match self.install_component(
274+
component,
275+
format,
276+
installer_file,
277+
tmp_cx,
278+
download_cfg,
279+
new_manifest,
280+
current_tx,
281+
) {
282+
Ok(new_tx) => {
283+
current_tx = new_tx;
284+
}
285+
Err(e) => {
286+
return Err(e);
287+
}
288+
}
289+
}
290+
Err(e) => {
291+
return Err(e);
292+
}
293+
}
294+
counter += 1;
295+
}
296+
}
297+
Ok(current_tx)
298+
},
299+
)
300+
.await;
301+
302+
things_downloaded = download_results;
303+
tx = install_result?;
260304
}
261305

262306
// Install new distribution manifest
@@ -508,7 +552,8 @@ impl Manifestation {
508552
download_cfg: &DownloadCfg<'_>,
509553
max_retries: usize,
510554
new_manifest: &Manifest,
511-
) -> Result<(Component, CompressionKind, File, String)> {
555+
notification_tx: mpsc::Sender<Result<(Component, CompressionKind, File)>>,
556+
) -> Result<String> {
512557
use tokio_retry::{RetryIf, strategy::FixedInterval};
513558

514559
let url = if altered {
@@ -537,9 +582,13 @@ impl Manifestation {
537582
.await
538583
.with_context(|| RustupError::ComponentDownloadFailed(component.name(new_manifest)))?;
539584

540-
Ok((component, format, downloaded_file, hash))
585+
let _ = notification_tx
586+
.send(Ok((component.clone(), format, downloaded_file)))
587+
.await;
588+
Ok(hash)
541589
}
542590

591+
#[allow(clippy::too_many_arguments)]
543592
fn install_component<'a>(
544593
&self,
545594
component: Component,

0 commit comments

Comments
 (0)