diff --git a/dragonfly-client-backend/src/http.rs b/dragonfly-client-backend/src/http.rs index 0a792172..596fdc69 100644 --- a/dragonfly-client-backend/src/http.rs +++ b/dragonfly-client-backend/src/http.rs @@ -59,7 +59,8 @@ use dragonfly_client_core::{ use dragonfly_client_util::tls::NoVerifier; use futures::TryStreamExt; use http::header::{ - HeaderName, HeaderValue, CONTENT_LENGTH, LOCATION, RANGE, TRANSFER_ENCODING, USER_AGENT, + HeaderName, HeaderValue, CONTENT_LENGTH, CONTENT_RANGE, LOCATION, RANGE, TRANSFER_ENCODING, + USER_AGENT, }; use lru::LruCache; use reqwest::header::HeaderMap; @@ -551,10 +552,7 @@ impl Backend for HTTP { let response_status_code = response.status(); let response_header = response.headers().clone(); - let content_length = match response_header.get(CONTENT_LENGTH) { - Some(content_length) => content_length.to_str()?.parse::().ok(), - None => response.content_length(), - }; + let content_length = content_length_from_response(&response_header, &response)?; debug!( "stat response {} {}: {:?} {:?} {:?}", @@ -824,6 +822,26 @@ fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, previous: &Url) } } +fn content_length_from_response( + headers: &HeaderMap, + response: &reqwest::Response, +) -> Result> { + if headers.contains_key(CONTENT_RANGE) { + return Ok(content_range_total(headers)); + } + + match headers.get(CONTENT_LENGTH) { + Some(content_length) => Ok(content_length.to_str()?.parse::().ok()), + None => Ok(response.content_length()), + } +} + +fn content_range_total(headers: &HeaderMap) -> Option { + let content_range = headers.get(CONTENT_RANGE)?.to_str().ok()?; + let (_, total) = content_range.rsplit_once('/')?; + total.parse::().ok() +} + #[cfg(test)] mod tests { use super::*; @@ -841,10 +859,26 @@ mod tests { use tokio_rustls::rustls::ServerConfig; use tokio_rustls::TlsAcceptor; use wiremock::{ - matchers::{method, path}, - Mock, ResponseTemplate, + matchers::{header, method, path}, + Mock, Request, Respond, ResponseTemplate, }; + #[test] + fn test_content_range_total() { + let mut headers = HeaderMap::new(); + headers.insert( + CONTENT_RANGE, + HeaderValue::from_static("bytes 5242880-10485759/12582912"), + ); + assert_eq!(content_range_total(&headers), Some(12_582_912)); + + headers.insert(CONTENT_RANGE, HeaderValue::from_static("bytes 0-1023/*")); + assert_eq!(content_range_total(&headers), None); + + headers.insert(CONTENT_RANGE, HeaderValue::from_static("invalid")); + assert_eq!(content_range_total(&headers), None); + } + // Generate the certificate and private key by script(`scripts/generate_certs.sh`). const SERVER_CERT: &str = r#""" -----BEGIN CERTIFICATE----- @@ -1634,4 +1668,200 @@ LJ8gCHKBOJy9dW62DcRWw6zzlTtt9y18/Btx0Hpawg== assert_eq!(response.http_status_code, Some(StatusCode::OK)); assert_eq!(response.text().await.unwrap(), "target content"); } + + struct AssertRangeResponder { + expected: &'static str, + body: &'static [u8], + content_range: &'static str, + } + + impl Respond for AssertRangeResponder { + fn respond(&self, request: &Request) -> ResponseTemplate { + let ranges: Vec = request + .headers + .get_all("Range") + .iter() + .map(|v| v.to_str().unwrap_or("").to_string()) + .collect(); + + if ranges.len() == 1 && ranges[0] == self.expected { + ResponseTemplate::new(206) + .insert_header("Content-Range", self.content_range) + .insert_header("Content-Length", self.body.len().to_string()) + .set_body_bytes(self.body) + } else { + ResponseTemplate::new(400).set_body_string(format!( + "expected Range={:?} but saw {:?}", + self.expected, ranges + )) + } + } + } + + #[tokio::test] + async fn should_forward_caller_range_when_getrequest_range_is_none() { + let server = wiremock::MockServer::start().await; + Mock::given(method("GET")) + .and(path("/object.bin")) + .and(header("Range", "bytes=5242880-10485759")) + .respond_with( + ResponseTemplate::new(206) + .insert_header("Content-Range", "bytes 5242880-10485759/12582912") + .insert_header("Content-Length", "5242880") + .set_body_bytes(vec![0u8; 5_242_880]), + ) + .expect(1) + .mount(&server) + .await; + + let mut http_header = HeaderMap::new(); + http_header.insert(RANGE, HeaderValue::from_static("bytes=5242880-10485759")); + http_header.insert( + reqwest::header::AUTHORIZATION, + HeaderValue::from_static( + "AWS4-HMAC-SHA256 Credential=AKIA/20260101/us-west-2/s3/aws4_request, \ + SignedHeaders=host;range;x-amz-content-sha256;x-amz-date, Signature=deadbeef", + ), + ); + + let resp = HTTP::new(HTTP_SCHEME, None, 1, true, Duration::from_secs(600), true) + .unwrap() + .get(GetRequest { + task_id: "test".to_string(), + piece_id: "test".to_string(), + url: format!("{}/object.bin", server.uri()), + range: None, + http_header: Some(http_header), + timeout: Duration::from_secs(5), + client_cert: None, + object_storage: None, + hdfs: None, + hugging_face: None, + model_scope: None, + }) + .await + .unwrap(); + + assert_eq!(resp.http_status_code, Some(StatusCode::PARTIAL_CONTENT)); + } + + #[tokio::test] + async fn should_overwrite_caller_range_when_getrequest_range_is_some() { + let server = wiremock::MockServer::start().await; + let responder = AssertRangeResponder { + expected: "bytes=0-1023", + body: &[0u8; 1024], + content_range: "bytes 0-1023/12582912", + }; + Mock::given(method("GET")) + .and(path("/object.bin")) + .respond_with(responder) + .expect(1) + .mount(&server) + .await; + + let mut http_header = HeaderMap::new(); + http_header.insert(RANGE, HeaderValue::from_static("bytes=5242880-10485759")); + + let resp = HTTP::new(HTTP_SCHEME, None, 1, true, Duration::from_secs(600), true) + .unwrap() + .get(GetRequest { + task_id: "test".to_string(), + piece_id: "test".to_string(), + url: format!("{}/object.bin", server.uri()), + range: Some(Range { + start: 0, + length: 1024, + }), + http_header: Some(http_header), + timeout: Duration::from_secs(5), + client_cert: None, + object_storage: None, + hdfs: None, + hugging_face: None, + model_scope: None, + }) + .await + .unwrap(); + + assert_eq!(resp.http_status_code, Some(StatusCode::PARTIAL_CONTENT)); + } + + #[tokio::test] + async fn should_stat_returns_content_range_total_as_content_length() { + let server = wiremock::MockServer::start().await; + Mock::given(method("GET")) + .and(path("/object.bin")) + .and(header("Range", "bytes=5242880-10485759")) + .respond_with( + ResponseTemplate::new(206) + .insert_header("Content-Range", "bytes 5242880-10485759/12582912") + .insert_header("Content-Length", "5242880") + .set_body_bytes(vec![0u8; 5_242_880]), + ) + .expect(1) + .mount(&server) + .await; + + let mut http_header = HeaderMap::new(); + http_header.insert(RANGE, HeaderValue::from_static("bytes=5242880-10485759")); + + let resp = HTTP::new(HTTP_SCHEME, None, 1, true, Duration::from_secs(600), true) + .unwrap() + .stat(StatRequest { + task_id: "test".to_string(), + url: format!("{}/object.bin", server.uri()), + http_header: Some(http_header), + timeout: Duration::from_secs(5), + client_cert: None, + object_storage: None, + hdfs: None, + hugging_face: None, + model_scope: None, + }) + .await + .unwrap(); + + assert!(resp.success); + assert_eq!(resp.content_length, Some(12_582_912)); + } + + #[tokio::test] + async fn should_stat_not_fallback_to_slice_length_when_content_range_total_is_unknown() { + let server = wiremock::MockServer::start().await; + Mock::given(method("GET")) + .and(path("/object.bin")) + .and(header("Range", "bytes=0-1023")) + .respond_with( + ResponseTemplate::new(206) + .insert_header("Content-Range", "bytes 0-1023/*") + .insert_header("Content-Length", "1024") + .set_body_bytes(vec![0u8; 1024]), + ) + .expect(1) + .mount(&server) + .await; + + let mut http_header = HeaderMap::new(); + http_header.insert(RANGE, HeaderValue::from_static("bytes=0-1023")); + + let resp = HTTP::new(HTTP_SCHEME, None, 1, true, Duration::from_secs(600), true) + .unwrap() + .stat(StatRequest { + task_id: "test".to_string(), + url: format!("{}/object.bin", server.uri()), + http_header: Some(http_header), + timeout: Duration::from_secs(5), + client_cert: None, + object_storage: None, + hdfs: None, + hugging_face: None, + model_scope: None, + }) + .await + .unwrap(); + + assert!(resp.success); + assert_eq!(resp.content_length, None); + } } diff --git a/dragonfly-client/src/grpc/dfdaemon_download.rs b/dragonfly-client/src/grpc/dfdaemon_download.rs index 53b47c6e..43815b55 100644 --- a/dragonfly-client/src/grpc/dfdaemon_download.rs +++ b/dragonfly-client/src/grpc/dfdaemon_download.rs @@ -18,6 +18,7 @@ use crate::dynconfig::Dynconfig; use crate::grpc::block_list::{ BlockList, DownloadBlockListCheckParams, UploadBlockListCheckParams, }; +use crate::proxy::header as proxy_header; use crate::resource::{persistent_cache_task, persistent_task, task}; use dragonfly_api::common::v2::{ CacheTask, PersistentCacheTask, PersistentTask, Priority, Task, TaskType, @@ -342,6 +343,37 @@ impl DfdaemonDownload for DfdaemonDownloadServerHandler { // If concurrent_piece_count is not set in the request, use the default value in the config. download.concurrent_piece_count = Some(self.config.download.concurrent_piece_count); + download.request_header.retain(|key, _| { + !key.eq_ignore_ascii_case( + proxy_header::DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER, + ) + }); + + if let Ok(parsed_header) = hashmap_to_headermap(&download.request_header) { + if proxy_header::range_header_is_signature_bound( + &parsed_header, + Some(download.url.as_str()), + ) { + download.request_header.insert( + proxy_header::DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER.to_string(), + "true".to_string(), + ); + + let range_value = parsed_header + .get(reqwest::header::RANGE) + .and_then(|v| v.to_str().ok()) + .unwrap_or_default(); + let task_id_content = download + .content_for_calculating_task_id + .take() + .unwrap_or_else(|| { + format!("{}\n{}", download.url, download.piece_length.unwrap_or(0)) + }); + download.content_for_calculating_task_id = + Some(format!("{}\n{}", task_id_content, range_value)); + } + } + // Generate the task id. let task_id = self .task diff --git a/dragonfly-client/src/proxy/header.rs b/dragonfly-client/src/proxy/header.rs index 1253d064..c8eaec5f 100644 --- a/dragonfly-client/src/proxy/header.rs +++ b/dragonfly-client/src/proxy/header.rs @@ -19,6 +19,7 @@ use dragonfly_api::common::v2::Priority; use reqwest::header::HeaderMap; use std::{fmt, str::FromStr}; use tracing::error; +use url::Url; /// DRAGONFLY_TAG_HEADER is the header key of tag in http request. pub const DRAGONFLY_TAG_HEADER: &str = "X-Dragonfly-Tag"; @@ -74,6 +75,9 @@ pub const DRAGONFLY_FORCE_HARD_LINK_HEADER: &str = "X-Dragonfly-Force-Hard-Link" /// to 4mib, for example: 4mib, 1gib pub const DRAGONFLY_PIECE_LENGTH_HEADER: &str = "X-Dragonfly-Piece-Length"; +pub const DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER: &str = + "X-Dragonfly-Preserve-Original-Range-For-Source"; + /// DRAGONFLY_CONTENT_FOR_CALCULATING_TASK_ID_HEADER is the header key of content for calculating task id. /// If DRAGONFLY_CONTENT_FOR_CALCULATING_TASK_ID_HEADER is set, use its value to calculate the task ID. /// Otherwise, calculate the task ID based on `url`, `piece_length`, `tag`, `application`, and `filtered_query_params`. @@ -295,6 +299,67 @@ pub fn get_piece_length(header: &HeaderMap) -> Option { } } +pub fn get_preserve_original_range_for_source(header: &HeaderMap) -> bool { + match header.get(DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER) { + Some(value) => match value.to_str() { + Ok(value) => value.eq_ignore_ascii_case("true"), + Err(err) => { + error!( + "get preserve original range for source from header failed: {}", + err + ); + false + } + }, + None => false, + } +} + +pub fn range_header_is_signature_bound(header: &HeaderMap, url: Option<&str>) -> bool { + if !header.contains_key(reqwest::header::RANGE) { + return false; + } + + if header + .get(reqwest::header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()) + .map(|value| sigv4_signed_headers_contains(value, "range")) + .unwrap_or(false) + { + return true; + } + + url.and_then(|value| Url::parse(value).ok()) + .map(|parsed_url| presigned_url_signs_header(&parsed_url, "range")) + .unwrap_or(false) +} + +fn sigv4_signed_headers_contains(authorization: &str, header_name: &str) -> bool { + authorization + .split(',') + .find_map(|part| part.trim().strip_prefix("SignedHeaders=")) + .map(|signed_headers| signed_headers_contains_header(signed_headers, header_name)) + .unwrap_or(false) +} + +fn presigned_url_signs_header(url: &Url, header_name: &str) -> bool { + url.query_pairs() + .find_map(|(key, value)| { + if key.eq_ignore_ascii_case("X-Amz-SignedHeaders") { + Some(signed_headers_contains_header(value.as_ref(), header_name)) + } else { + None + } + }) + .unwrap_or(false) +} + +fn signed_headers_contains_header(signed_headers: &str, header_name: &str) -> bool { + signed_headers + .split(';') + .any(|value| value.trim().eq_ignore_ascii_case(header_name)) +} + /// Get X-Dragonfly-Content-For-Calculating-Task-ID header value to determine the content for /// calculating task ID. pub fn get_content_for_calculating_task_id(header: &HeaderMap) -> Option { @@ -517,4 +582,125 @@ mod tests { assert!(get_enable_task_id_based_blob_digest(&empty_headers, true)); assert!(!get_enable_task_id_based_blob_digest(&empty_headers, false)); } + + #[test] + fn test_get_preserve_original_range_for_source() { + let mut headers = HeaderMap::new(); + headers.insert( + DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER, + HeaderValue::from_static("true"), + ); + assert!(get_preserve_original_range_for_source(&headers)); + + headers.insert( + DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER, + HeaderValue::from_static("TRUE"), + ); + assert!(get_preserve_original_range_for_source(&headers)); + + headers.insert( + DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER, + HeaderValue::from_static("false"), + ); + assert!(!get_preserve_original_range_for_source(&headers)); + + let empty_headers = HeaderMap::new(); + assert!(!get_preserve_original_range_for_source(&empty_headers)); + } + + #[test] + fn test_range_header_is_signature_bound_sigv4() { + let mut headers = HeaderMap::new(); + headers.insert( + reqwest::header::RANGE, + HeaderValue::from_static("bytes=0-1023"), + ); + headers.insert( + reqwest::header::AUTHORIZATION, + HeaderValue::from_static( + "AWS4-HMAC-SHA256 \ + Credential=AKIA/20260101/us-west-2/s3/aws4_request, \ + SignedHeaders=host;range;x-amz-content-sha256;x-amz-date, \ + Signature=deadbeef", + ), + ); + assert!(range_header_is_signature_bound(&headers, None)); + + headers.insert( + reqwest::header::AUTHORIZATION, + HeaderValue::from_static( + "AWS4-HMAC-SHA256 Credential=x, SignedHeaders=HOST;Range;x-amz-date, Signature=y", + ), + ); + assert!(range_header_is_signature_bound(&headers, None)); + } + + #[test] + fn test_range_header_is_signature_bound_not_signing_range() { + let mut headers = HeaderMap::new(); + headers.insert( + reqwest::header::RANGE, + HeaderValue::from_static("bytes=0-1023"), + ); + headers.insert( + reqwest::header::AUTHORIZATION, + HeaderValue::from_static( + "AWS4-HMAC-SHA256 Credential=x, SignedHeaders=host;x-amz-date, Signature=y", + ), + ); + assert!(!range_header_is_signature_bound(&headers, None)); + } + + #[test] + fn test_range_header_is_signature_bound_no_range_header() { + let mut headers = HeaderMap::new(); + headers.insert( + reqwest::header::AUTHORIZATION, + HeaderValue::from_static( + "AWS4-HMAC-SHA256 Credential=x, SignedHeaders=host;range;x-amz-date, Signature=y", + ), + ); + assert!(!range_header_is_signature_bound(&headers, None)); + } + + #[test] + fn test_range_header_is_signature_bound_presigned_url() { + let mut headers = HeaderMap::new(); + headers.insert( + reqwest::header::RANGE, + HeaderValue::from_static("bytes=0-1023"), + ); + assert!(range_header_is_signature_bound( + &headers, + Some( + "https://bucket.s3.us-west-2.amazonaws.com/k.bin\ + ?X-Amz-Algorithm=AWS4-HMAC-SHA256\ + &X-Amz-SignedHeaders=host%3Brange\ + &X-Amz-Signature=deadbeef", + ), + )); + + assert!(!range_header_is_signature_bound( + &headers, + Some( + "https://bucket.s3.us-west-2.amazonaws.com/k.bin\ + ?X-Amz-SignedHeaders=host\ + &X-Amz-Signature=deadbeef", + ), + )); + } + + #[test] + fn test_signed_headers_contains_header() { + assert!(signed_headers_contains_header( + "host;range;x-amz-date", + "range" + )); + assert!(signed_headers_contains_header( + "HOST;Range;X-Amz-Date", + "range" + )); + assert!(!signed_headers_contains_header("host;x-amz-date", "range")); + assert!(!signed_headers_contains_header("", "range")); + } } diff --git a/dragonfly-client/src/resource/piece.rs b/dragonfly-client/src/resource/piece.rs index 851dfe29..cff9ed79 100644 --- a/dragonfly-client/src/resource/piece.rs +++ b/dragonfly-client/src/resource/piece.rs @@ -475,6 +475,42 @@ impl Piece { hdfs: Option, hugging_face: Option, model_scope: Option, + ) -> Result { + self.download_from_source_with_options( + piece_id, + task_id, + number, + url, + offset, + length, + request_header, + is_prefetch, + false, + object_storage, + hdfs, + hugging_face, + model_scope, + ) + .await + } + + #[allow(clippy::too_many_arguments)] + #[instrument(skip_all, fields(piece_id))] + pub async fn download_from_source_with_options( + &self, + piece_id: &str, + task_id: &str, + number: u32, + url: &str, + offset: u64, + length: u64, + request_header: HeaderMap, + is_prefetch: bool, + preserve_original_range_for_source: bool, + object_storage: Option, + hdfs: Option, + hugging_face: Option, + model_scope: Option, ) -> Result { // Span record the piece_id. Span::current().record("piece_id", piece_id); @@ -537,10 +573,14 @@ impl Piece { task_id: task_id.to_string(), piece_id: piece_id.to_string(), url: url.to_string(), - range: Some(Range { - start: offset, - length, - }), + range: if preserve_original_range_for_source { + None + } else { + Some(Range { + start: offset, + length, + }) + }, http_header: Some(request_header), timeout: self.config.download.piece_timeout, client_cert: None, diff --git a/dragonfly-client/src/resource/task.rs b/dragonfly-client/src/resource/task.rs index 0a0dcf08..b7313ff0 100644 --- a/dragonfly-client/src/resource/task.rs +++ b/dragonfly-client/src/resource/task.rs @@ -15,9 +15,11 @@ */ use crate::grpc::{scheduler::SchedulerClient, REQUEST_TIMEOUT}; +use crate::proxy::header as proxy_header; use crate::resource::parent_selector::ParentSelector; +use chrono::Utc; use dragonfly_api::common::v2::{ - Download, Hdfs, HuggingFace, ModelScope, ObjectStorage, Peer, Piece, Task as CommonTask, + Download, Hdfs, HuggingFace, ModelScope, ObjectStorage, Peer, Piece, Range, Task as CommonTask, TrafficType, }; use dragonfly_api::dfdaemon::{ @@ -180,10 +182,13 @@ impl Task { error!("convert header: {}", err); })?; - // Remove the range header to prevent the server from - // returning a 206 partial content and returning - // a 200 full content. - request_header.remove(reqwest::header::RANGE); + let preserve_original_range_for_source = + proxy_header::get_preserve_original_range_for_source(&request_header); + request_header.remove(proxy_header::DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER); + + if !preserve_original_range_for_source { + request_header.remove(reqwest::header::RANGE); + } // Head the url to get the content length. let backend = self.backend_factory.build(request.url.as_str())?; @@ -244,17 +249,19 @@ impl Task { None => return Err(Error::InvalidContentLength), }; - let piece_length = match request.piece_length { - Some(piece_length) => self - .piece - .calculate_piece_length(piece::PieceLengthStrategy::FixedPieceLength(piece_length)), - None => { - self.piece - .calculate_piece_length(piece::PieceLengthStrategy::OptimizeByFileLength( - content_length, - )) - } - }; + let piece_length = + if preserve_original_range_for_source { + content_length + } else { + match request.piece_length { + Some(piece_length) => self.piece.calculate_piece_length( + piece::PieceLengthStrategy::FixedPieceLength(piece_length), + ), + None => self.piece.calculate_piece_length( + piece::PieceLengthStrategy::OptimizeByFileLength(content_length), + ), + } + }; // If the task is not finished, check if the storage has enough space to // store the task. @@ -356,17 +363,18 @@ impl Task { }; // Calculate the interested pieces to download. - let interested_pieces = - match self - .piece - .calculate_interested(piece_length, content_length, request.range) - { - Ok(interested_pieces) => interested_pieces, - Err(err) => { - error!("calculate interested pieces error: {:?}", err); - return Err(err); - } - }; + let interested_pieces = match Self::calculate_interested_pieces( + &self.piece, + piece_length, + content_length, + &request, + ) { + Ok(interested_pieces) => interested_pieces, + Err(err) => { + error!("calculate interested pieces error: {:?}", err); + return Err(err); + } + }; debug!( "interested pieces: {:?}", interested_pieces @@ -1317,10 +1325,14 @@ impl Task { let task_id = task.id.as_str(); // Convert the header. - let request_header: HeaderMap = (&request.request_header) + let mut request_header: HeaderMap = (&request.request_header) .try_into() .or_err(ErrorType::ParseError)?; + let preserve_original_range_for_source = + proxy_header::get_preserve_original_range_for_source(&request_header); + request_header.remove(proxy_header::DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER); + // Initialize the finished pieces. let mut finished_pieces: Vec = Vec::new(); @@ -1340,6 +1352,7 @@ impl Task { length: u64, request_header: HeaderMap, is_prefetch: bool, + preserve_original_range_for_source: bool, need_piece_content: bool, piece_manager: Arc, download_progress_tx: Sender>, @@ -1353,7 +1366,7 @@ impl Task { info!("start to download piece {} from source", piece_id); let metadata = piece_manager - .download_from_source( + .download_from_source_with_options( piece_id.as_str(), task_id.as_str(), number, @@ -1362,6 +1375,7 @@ impl Task { length, request_header, is_prefetch, + preserve_original_range_for_source, object_storage, hdfs, hugging_face, @@ -1481,6 +1495,7 @@ impl Task { interested_piece.length, request_header, request.is_prefetch, + preserve_original_range_for_source, request.need_piece_content, piece_manager, download_progress_tx, @@ -1723,10 +1738,14 @@ impl Task { let task_id = task.id.as_str(); // Convert the header. - let request_header: HeaderMap = (&request.request_header) + let mut request_header: HeaderMap = (&request.request_header) .try_into() .or_err(ErrorType::ParseError)?; + let preserve_original_range_for_source = + proxy_header::get_preserve_original_range_for_source(&request_header); + request_header.remove(proxy_header::DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER); + // Initialize the finished pieces. let mut finished_pieces: Vec = Vec::new(); @@ -1746,6 +1765,7 @@ impl Task { length: u64, request_header: HeaderMap, is_prefetch: bool, + preserve_original_range_for_source: bool, need_piece_content: bool, piece_manager: Arc, download_progress_tx: Sender>, @@ -1758,7 +1778,7 @@ impl Task { info!("start to download piece {} from source", piece_id); let metadata = piece_manager - .download_from_source( + .download_from_source_with_options( piece_id.as_str(), task_id.as_str(), number, @@ -1767,6 +1787,7 @@ impl Task { length, request_header, is_prefetch, + preserve_original_range_for_source, object_storage, hdfs, hugging_face, @@ -1863,6 +1884,7 @@ impl Task { interested_piece.length, request_header, request.is_prefetch, + preserve_original_range_for_source, request.need_piece_content, piece_manager, download_progress_tx, @@ -1910,6 +1932,44 @@ impl Task { return Ok(finished_pieces); } + fn calculate_interested_pieces( + piece_manager: &piece::Piece, + piece_length: u64, + content_length: u64, + request: &Download, + ) -> ClientResult> { + if Self::should_preserve_original_range_for_source(request) { + if let Some(range) = request.range { + return Ok(vec![Self::signed_range_piece(range)]); + } + } + + piece_manager.calculate_interested(piece_length, content_length, request.range) + } + + fn should_preserve_original_range_for_source(request: &Download) -> bool { + request.request_header.iter().any(|(key, value)| { + key.eq_ignore_ascii_case( + proxy_header::DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER, + ) && value.eq_ignore_ascii_case("true") + }) + } + + fn signed_range_piece(range: Range) -> metadata::Piece { + metadata::Piece { + number: 0, + offset: range.start, + length: range.length, + digest: "".to_string(), + parent_id: None, + uploading_count: 0, + uploaded_count: 0, + updated_at: Utc::now().naive_utc(), + created_at: Utc::now().naive_utc(), + finished_at: None, + } + } + /// stat_task returns the task metadata from scheduler. #[instrument(skip_all)] pub async fn stat(&self, task_id: &str, host_id: &str) -> ClientResult { @@ -2031,6 +2091,7 @@ impl Task { #[cfg(test)] mod tests { use super::*; + use std::collections::HashMap; use std::sync::Arc; use tempfile::tempdir; @@ -2077,4 +2138,32 @@ mod tests { let task = storage.get_task(task_id).unwrap(); assert!(task.is_none(), "task should be deleted"); } + + #[test] + fn test_should_preserve_original_range_for_source() { + let request = Download { + request_header: HashMap::from([( + proxy_header::DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER.to_string(), + "true".to_string(), + )]), + ..Default::default() + }; + assert!(Task::should_preserve_original_range_for_source(&request)); + + let request = Download::default(); + assert!(!Task::should_preserve_original_range_for_source(&request)); + } + + #[test] + fn test_signed_range_piece_uses_original_offset_and_single_piece_number() { + let range = Range { + start: 5 * 1024 * 1024, + length: 5 * 1024 * 1024, + }; + + let piece = Task::signed_range_piece(range); + assert_eq!(piece.offset, range.start); + assert_eq!(piece.length, range.length); + assert_eq!(piece.number, 0); + } }