@@ -7,9 +7,9 @@ mod tests;
7
7
use std:: path:: Path ;
8
8
9
9
use anyhow:: { Context , Result , anyhow, bail} ;
10
- use futures_util:: stream:: StreamExt ;
10
+ use futures_util:: { future :: join , stream:: StreamExt } ;
11
11
use std:: sync:: Arc ;
12
- use tokio:: sync:: Semaphore ;
12
+ use tokio:: sync:: { Semaphore , mpsc } ;
13
13
use tracing:: info;
14
14
15
15
use crate :: dist:: component:: {
@@ -153,7 +153,6 @@ impl Manifestation {
153
153
let altered = tmp_cx. dist_server != DEFAULT_DIST_SERVER ;
154
154
155
155
// Download component packages and validate hashes
156
- let mut things_to_install: Vec < ( Component , CompressionKind , File ) > = Vec :: new ( ) ;
157
156
let mut things_downloaded: Vec < String > = Vec :: new ( ) ;
158
157
let components = update. components_urls_and_hashes ( new_manifest) ?;
159
158
let components_len = components. len ( ) ;
@@ -170,49 +169,7 @@ impl Manifestation {
170
169
. and_then ( |s| s. parse ( ) . ok ( ) )
171
170
. unwrap_or ( DEFAULT_MAX_RETRIES ) ;
172
171
173
- info ! ( "downloading component(s)" ) ;
174
- for ( component, _, url, _) in components. clone ( ) {
175
- ( download_cfg. notify_handler ) ( Notification :: DownloadingComponent (
176
- & component. short_name ( new_manifest) ,
177
- & self . target_triple ,
178
- component. target . as_ref ( ) ,
179
- & url,
180
- ) ) ;
181
- }
182
-
183
- let semaphore = Arc :: new ( Semaphore :: new ( num_channels) ) ;
184
- let component_stream =
185
- tokio_stream:: iter ( components. into_iter ( ) ) . map ( |( component, format, url, hash) | {
186
- let sem = semaphore. clone ( ) ;
187
- async move {
188
- let _permit = sem. acquire ( ) . await . unwrap ( ) ;
189
- self . download_component (
190
- component,
191
- format,
192
- url,
193
- hash,
194
- altered,
195
- tmp_cx,
196
- download_cfg,
197
- max_retries,
198
- new_manifest,
199
- )
200
- . await
201
- }
202
- } ) ;
203
- if num_channels > 0 {
204
- let results = component_stream
205
- . buffered ( components_len)
206
- . collect :: < Vec < _ > > ( )
207
- . await ;
208
- for result in results {
209
- let ( component, format, downloaded_file, hash) = result?;
210
- things_downloaded. push ( hash) ;
211
- things_to_install. push ( ( component, format, downloaded_file) ) ;
212
- }
213
- }
214
-
215
- // Begin transaction
172
+ // Begin transaction before the downloads, as installations are interleaved with those
216
173
let mut tx = Transaction :: new (
217
174
prefix. clone ( ) ,
218
175
tmp_cx,
@@ -246,17 +203,104 @@ impl Manifestation {
246
203
) ?;
247
204
}
248
205
249
- // Install components
250
- for ( component, format, installer_file) in things_to_install {
251
- tx = self . install_component (
252
- component,
253
- format,
254
- installer_file,
255
- tmp_cx,
256
- download_cfg,
257
- new_manifest,
258
- tx,
259
- ) ?;
206
+ if num_channels > 0 {
207
+ info ! ( "downloading component(s)" ) ;
208
+ for ( component, _, url, _) in components. clone ( ) {
209
+ ( download_cfg. notify_handler ) ( Notification :: DownloadingComponent (
210
+ & component. short_name ( new_manifest) ,
211
+ & self . target_triple ,
212
+ component. target . as_ref ( ) ,
213
+ & url,
214
+ ) ) ;
215
+ }
216
+
217
+ // Create a channel to communicate whenever a download is done and the component can be installed
218
+ // The `mpsc` channel was used as we need to send many messages from one producer (download's thread) to one consumer (install's thread)
219
+ // This is recommended in the official docs: https://docs.rs/tokio/latest/tokio/sync/index.html#mpsc-channel
220
+ let total_components = components. len ( ) ;
221
+ let ( download_tx, mut download_rx) =
222
+ mpsc:: channel :: < Result < ( Component , CompressionKind , File ) > > ( total_components) ;
223
+
224
+ let semaphore = Arc :: new ( Semaphore :: new ( num_channels) ) ;
225
+ let component_stream =
226
+ tokio_stream:: iter ( components. into_iter ( ) ) . map ( |( component, format, url, hash) | {
227
+ let sem = semaphore. clone ( ) ;
228
+ let download_tx_cloned = download_tx. clone ( ) ;
229
+ async move {
230
+ let _permit = sem. acquire ( ) . await . unwrap ( ) ;
231
+ self . download_component (
232
+ component,
233
+ format,
234
+ url,
235
+ hash,
236
+ altered,
237
+ tmp_cx,
238
+ download_cfg,
239
+ max_retries,
240
+ new_manifest,
241
+ download_tx_cloned,
242
+ )
243
+ . await
244
+ }
245
+ } ) ;
246
+
247
+ let mut stream = component_stream. buffered ( num_channels) ;
248
+ let ( download_results, install_result) = join (
249
+ async {
250
+ let mut hashes = Vec :: new ( ) ;
251
+ while let Some ( result) = stream. next ( ) . await {
252
+ match result {
253
+ Ok ( hash) => {
254
+ hashes. push ( hash) ;
255
+ }
256
+ Err ( e) => {
257
+ let _ = download_tx. send ( Err ( e) ) . await ;
258
+ }
259
+ }
260
+ }
261
+ hashes
262
+ } ,
263
+ async {
264
+ let mut current_tx = tx;
265
+ let mut counter = 0 ;
266
+ loop {
267
+ if counter >= total_components {
268
+ break ;
269
+ }
270
+ if let Some ( message) = download_rx. recv ( ) . await {
271
+ match message {
272
+ Ok ( ( component, format, installer_file) ) => {
273
+ match self . install_component (
274
+ component,
275
+ format,
276
+ installer_file,
277
+ tmp_cx,
278
+ download_cfg,
279
+ new_manifest,
280
+ current_tx,
281
+ ) {
282
+ Ok ( new_tx) => {
283
+ current_tx = new_tx;
284
+ }
285
+ Err ( e) => {
286
+ return Err ( e) ;
287
+ }
288
+ }
289
+ }
290
+ Err ( e) => {
291
+ return Err ( e) ;
292
+ }
293
+ }
294
+ counter += 1 ;
295
+ }
296
+ }
297
+ Ok ( current_tx)
298
+ } ,
299
+ )
300
+ . await ;
301
+
302
+ things_downloaded = download_results;
303
+ tx = install_result?;
260
304
}
261
305
262
306
// Install new distribution manifest
@@ -508,7 +552,8 @@ impl Manifestation {
508
552
download_cfg : & DownloadCfg < ' _ > ,
509
553
max_retries : usize ,
510
554
new_manifest : & Manifest ,
511
- ) -> Result < ( Component , CompressionKind , File , String ) > {
555
+ notification_tx : mpsc:: Sender < Result < ( Component , CompressionKind , File ) > > ,
556
+ ) -> Result < String > {
512
557
use tokio_retry:: { RetryIf , strategy:: FixedInterval } ;
513
558
514
559
let url = if altered {
@@ -537,9 +582,13 @@ impl Manifestation {
537
582
. await
538
583
. with_context ( || RustupError :: ComponentDownloadFailed ( component. name ( new_manifest) ) ) ?;
539
584
540
- Ok ( ( component, format, downloaded_file, hash) )
585
+ let _ = notification_tx
586
+ . send ( Ok ( ( component. clone ( ) , format, downloaded_file) ) )
587
+ . await ;
588
+ Ok ( hash)
541
589
}
542
590
591
+ #[ allow( clippy:: too_many_arguments) ]
543
592
fn install_component < ' a > (
544
593
& self ,
545
594
component : Component ,
0 commit comments