|
1 | 1 | use super::*; |
2 | 2 |
|
| 3 | +use std::future::Future; |
| 4 | +use std::ops::Range; |
| 5 | +use std::sync::atomic::AtomicUsize; |
| 6 | +use std::sync::atomic::Ordering; |
3 | 7 | use tokio::io::AsyncBufReadExt; |
4 | 8 | use tokio::io::AsyncWriteExt; |
5 | 9 | use tokio::io::BufReader; |
@@ -63,7 +67,13 @@ async fn resume_download_when_disconnected() { |
63 | 67 | let body = &file[start..next]; |
64 | 68 |
|
65 | 69 | let size = body.len(); |
66 | | - let msg = format!("{header}\r\n{size}\r\n{body}\r\n"); |
| 70 | + let msg = format!("{header}\r\n{size:x}\r\n{body}\r\n"); |
| 71 | + // if this is the last chunk, send also terminating 0-length chunk |
| 72 | + let msg = if next == file.len() { |
| 73 | + format!("{msg}0\r\n\r\n") |
| 74 | + } else { |
| 75 | + msg |
| 76 | + }; |
67 | 77 | debug!("sending message = {msg}"); |
68 | 78 | writer.write_all(msg.as_bytes()).await.unwrap(); |
69 | 79 | writer.flush().await.unwrap(); |
@@ -330,3 +340,147 @@ async fn resumed_download_doesnt_leave_leftovers() { |
330 | 340 |
|
331 | 341 | server_task.abort(); |
332 | 342 | } |
| 343 | + |
| 344 | +#[tokio::test] |
| 345 | +async fn resume_max_5_times() { |
| 346 | + let request_count = Arc::new(AtomicUsize::new(0)); |
| 347 | + let rc = request_count.clone(); |
| 348 | + |
| 349 | + let handler = move |_, _| { |
| 350 | + let rc = rc.clone(); |
| 351 | + async move { |
| 352 | + rc.fetch_add(1, Ordering::SeqCst); |
| 353 | + // Always respond with only first chunk, never completing the response, triggering retries |
| 354 | + let header = "\ |
| 355 | + HTTP/1.1 200 OK\r\n\ |
| 356 | + transfer-encoding: chunked\r\n\ |
| 357 | + connection: close\r\n\ |
| 358 | + content-type: application/octet-stream\r\n\ |
| 359 | + accept-ranges: bytes\r\n"; |
| 360 | + |
| 361 | + let body = "AAAA"; |
| 362 | + format!("{header}\r\n4\r\n{body}\r\n") |
| 363 | + } |
| 364 | + }; |
| 365 | + let (port, server_task) = spawn_server(handler).await; |
| 366 | + |
| 367 | + // Wait until task binds a listener on the TCP port |
| 368 | + tokio::time::sleep(std::time::Duration::from_millis(50)).await; |
| 369 | + |
| 370 | + let tmpdir = TempDir::new().unwrap(); |
| 371 | + let target_path = tmpdir.path().join("partial_download"); |
| 372 | + |
| 373 | + let downloader = Downloader::new(target_path, None, CloudHttpConfig::test_value()); |
| 374 | + let url = DownloadInfo::new(&format!("http://localhost:{port}/")); |
| 375 | + |
| 376 | + let err = downloader.download(&url).await.unwrap_err(); |
| 377 | + assert!(matches!(err, DownloadError::Request(_))); |
| 378 | + assert!(err.to_string().contains("error decoding response body")); |
| 379 | + |
| 380 | + downloader.cleanup().await.unwrap(); |
| 381 | + |
| 382 | + server_task.abort(); |
| 383 | + |
| 384 | + assert_eq!(request_count.load(Ordering::SeqCst), 5); |
| 385 | +} |
| 386 | + |
| 387 | +// If we succeed before max retries, we should not do more requests. |
| 388 | +#[tokio::test] |
| 389 | +async fn only_retry_until_success() { |
| 390 | + let file = "AAAABBBBCCCCDDDD"; |
| 391 | + let request_count = Arc::new(AtomicUsize::new(0)); |
| 392 | + let rc = request_count.clone(); |
| 393 | + |
| 394 | + let handler = move |_, request_count| { |
| 395 | + let rc = rc.clone(); |
| 396 | + async move { |
| 397 | + rc.fetch_add(1, Ordering::SeqCst); |
| 398 | + // Always respond with only first chunk, never completing the response, triggering retries |
| 399 | + let header = "\ |
| 400 | + HTTP/1.1 200 OK\r\n\ |
| 401 | + transfer-encoding: chunked\r\n\ |
| 402 | + connection: close\r\n\ |
| 403 | + content-type: application/octet-stream\r\n\ |
| 404 | + accept-ranges: bytes\r\n"; |
| 405 | + |
| 406 | + // On the 2nd request, send the full response, should trigger only 1 retry |
| 407 | + if request_count == 0 { |
| 408 | + let body = "AAAA"; |
| 409 | + format!("{header}\r\n4\r\n{body}\r\n") |
| 410 | + } else { |
| 411 | + let body = file; |
| 412 | + let len = file.len(); |
| 413 | + format!("{header}\r\n{len:x}\r\n{body}\r\n0\r\n\r\n") |
| 414 | + } |
| 415 | + } |
| 416 | + }; |
| 417 | + let (port, server_task) = spawn_server(handler).await; |
| 418 | + |
| 419 | + // Wait until task binds a listener on the TCP port |
| 420 | + tokio::time::sleep(std::time::Duration::from_millis(50)).await; |
| 421 | + |
| 422 | + let tmpdir = TempDir::new().unwrap(); |
| 423 | + let target_path = tmpdir.path().join("partial_download"); |
| 424 | + |
| 425 | + let downloader = Downloader::new(target_path, None, CloudHttpConfig::test_value()); |
| 426 | + let url = DownloadInfo::new(&format!("http://localhost:{port}/")); |
| 427 | + |
| 428 | + downloader.download(&url).await.unwrap(); |
| 429 | + let saved_file = std::fs::read_to_string(downloader.filename()).unwrap(); |
| 430 | + assert_eq!(saved_file, file); |
| 431 | + |
| 432 | + downloader.cleanup().await.unwrap(); |
| 433 | + |
| 434 | + server_task.abort(); |
| 435 | + |
| 436 | + assert_eq!(request_count.load(Ordering::SeqCst), 2); |
| 437 | +} |
| 438 | + |
| 439 | +async fn spawn_server<F, Fut>(handle_request: F) -> (u16, tokio::task::JoinHandle<()>) |
| 440 | +where |
| 441 | + F: Fn(Option<Range<usize>>, u32) -> Fut + 'static + Send + Sync + Clone, |
| 442 | + Fut: Future<Output = String> + Send + Sync + 'static, |
| 443 | +{ |
| 444 | + let listener = TcpListener::bind("localhost:0").await.unwrap(); |
| 445 | + let port = listener.local_addr().unwrap().port(); |
| 446 | + let server_task = tokio::spawn(async move { |
| 447 | + let mut request_count = 0; |
| 448 | + while let Ok((mut stream, _addr)) = listener.accept().await { |
| 449 | + let handler = handle_request.clone(); |
| 450 | + tokio::spawn(async move { |
| 451 | + let (reader, mut writer) = stream.split(); |
| 452 | + let request_range = parse_request_range(reader).await; |
| 453 | + |
| 454 | + let response: String = handler(request_range, request_count).await; |
| 455 | + |
| 456 | + writer.write_all(response.as_bytes()).await.unwrap(); |
| 457 | + writer.flush().await.unwrap(); |
| 458 | + }); |
| 459 | + request_count += 1; |
| 460 | + } |
| 461 | + }); |
| 462 | + (port, server_task) |
| 463 | +} |
| 464 | + |
| 465 | +/// Parse HTTP request from a BufReader, stopping at empty line |
| 466 | +async fn parse_request_range(reader: tokio::net::tcp::ReadHalf<'_>) -> Option<Range<usize>> { |
| 467 | + let reader = BufReader::new(reader); |
| 468 | + let mut lines = reader.lines(); |
| 469 | + let mut range = None; |
| 470 | + |
| 471 | + while let Ok(Some(line)) = lines.next_line().await { |
| 472 | + if line.to_ascii_lowercase().contains("range:") { |
| 473 | + let (_, bytes) = line.split_once('=').unwrap(); |
| 474 | + let (start, end) = bytes.split_once('-').unwrap(); |
| 475 | + let start = start.parse().unwrap_or(0); |
| 476 | + let end = end.parse().unwrap_or(0); |
| 477 | + range = Some(start..end) |
| 478 | + } |
| 479 | + // On `\r\n\r\n` (empty line) stop reading the request |
| 480 | + if line.is_empty() { |
| 481 | + break; |
| 482 | + } |
| 483 | + } |
| 484 | + |
| 485 | + range |
| 486 | +} |
0 commit comments