Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 237 additions & 7 deletions dragonfly-client-backend/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ use dragonfly_client_core::{
use dragonfly_client_util::tls::NoVerifier;
use futures::TryStreamExt;
use http::header::{
HeaderName, HeaderValue, CONTENT_LENGTH, LOCATION, RANGE, TRANSFER_ENCODING, USER_AGENT,
HeaderName, HeaderValue, CONTENT_LENGTH, CONTENT_RANGE, LOCATION, RANGE, TRANSFER_ENCODING,
USER_AGENT,
};
use lru::LruCache;
use reqwest::header::HeaderMap;
Expand Down Expand Up @@ -551,10 +552,7 @@ impl Backend for HTTP {

let response_status_code = response.status();
let response_header = response.headers().clone();
let content_length = match response_header.get(CONTENT_LENGTH) {
Some(content_length) => content_length.to_str()?.parse::<u64>().ok(),
None => response.content_length(),
};
let content_length = content_length_from_response(&response_header, &response)?;

debug!(
"stat response {} {}: {:?} {:?} {:?}",
Expand Down Expand Up @@ -824,6 +822,26 @@ fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, previous: &Url)
}
}

/// Resolves the content length to report for a response.
///
/// Resolution order:
/// 1. If a `Content-Range` header is present, only its total-size component
///    is used. A total that is missing or not a valid `u64` (for example `*`)
///    yields `Ok(None)`; `Content-Length` is not consulted in that case.
/// 2. Otherwise the `Content-Length` header is parsed as a `u64`; a value that
///    does not parse yields `Ok(None)`.
/// 3. If neither header is present, `response.content_length()` is returned.
///
/// An error is returned only when the `Content-Length` header value cannot be
/// converted to a string.
fn content_length_from_response(
    headers: &HeaderMap,
    response: &reqwest::Response,
) -> Result<Option<u64>> {
    let resolved = if headers.contains_key(CONTENT_RANGE) {
        content_range_total(headers)
    } else if let Some(raw_length) = headers.get(CONTENT_LENGTH) {
        raw_length.to_str()?.parse::<u64>().ok()
    } else {
        response.content_length()
    };

    Ok(resolved)
}

/// Extracts the total-size component of the `Content-Range` header.
///
/// The total is whatever follows the last `/` in the header value, parsed as
/// a `u64`. Returns `None` when the header is absent, its value cannot be
/// converted to a string, it contains no `/`, or the text after the last `/`
/// is not a valid `u64` (for example `*`).
fn content_range_total(headers: &HeaderMap) -> Option<u64> {
    headers
        .get(CONTENT_RANGE)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.rsplit_once('/'))
        .and_then(|(_, total)| total.parse::<u64>().ok())
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -841,10 +859,26 @@ mod tests {
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::TlsAcceptor;
use wiremock::{
matchers::{method, path},
Mock, ResponseTemplate,
matchers::{header, method, path},
Mock, Request, Respond, ResponseTemplate,
};

/// Checks `content_range_total` against a known total, an unknown (`*`)
/// total, and a value without any `/` separator.
#[test]
fn test_content_range_total() {
    let cases: [(&'static str, Option<u64>); 3] = [
        ("bytes 5242880-10485759/12582912", Some(12_582_912)),
        ("bytes 0-1023/*", None),
        ("invalid", None),
    ];

    // The same map is reused; each insert replaces the previous value.
    let mut headers = HeaderMap::new();
    for &(raw, expected) in cases.iter() {
        headers.insert(CONTENT_RANGE, HeaderValue::from_static(raw));
        assert_eq!(content_range_total(&headers), expected);
    }
}

// The certificate and private key are generated by the script `scripts/generate_certs.sh`.
const SERVER_CERT: &str = r#"""
-----BEGIN CERTIFICATE-----
Expand Down Expand Up @@ -1634,4 +1668,200 @@ LJ8gCHKBOJy9dW62DcRWw6zzlTtt9y18/Btx0Hpawg==
assert_eq!(response.http_status_code, Some(StatusCode::OK));
assert_eq!(response.text().await.unwrap(), "target content");
}

/// Test responder that checks the `Range` header of an incoming request
/// before answering: a matching request gets a 206 response, anything else
/// gets a 400 response describing the mismatch.
struct AssertRangeResponder {
/// Exact value the request's single `Range` header must have.
expected: &'static str,
/// Bytes returned as the body of the 206 response; their length is also
/// sent as the `Content-Length` header.
body: &'static [u8],
/// Value sent as the `Content-Range` header of the 206 response.
content_range: &'static str,
}

impl Respond for AssertRangeResponder {
    /// Answers 206 with the configured body and `Content-Range` when the
    /// request carries exactly one `Range` header equal to `self.expected`;
    /// otherwise answers 400 with a message listing the `Range` values seen.
    fn respond(&self, request: &Request) -> ResponseTemplate {
        // Gather every Range value; values that cannot be read as a string
        // are shown with a placeholder.
        let mut seen = Vec::new();
        for value in request.headers.get_all("Range").iter() {
            seen.push(value.to_str().unwrap_or("<non-ascii>").to_string());
        }

        let is_expected = matches!(
            seen.as_slice(),
            [only] if only.as_str() == self.expected
        );
        if !is_expected {
            let message = format!(
                "expected Range={:?} but saw {:?}",
                self.expected, seen
            );
            return ResponseTemplate::new(400).set_body_string(message);
        }

        let body_length = self.body.len().to_string();
        ResponseTemplate::new(206)
            .insert_header("Content-Range", self.content_range)
            .insert_header("Content-Length", body_length)
            .set_body_bytes(self.body)
    }
}

/// With `GetRequest::range` set to `None`, the `Range` header supplied via
/// `http_header` must reach the upstream server as-is.
///
/// The mock only matches `Range: bytes=5242880-10485759` and must be hit
/// exactly once. The request also carries an AWS SigV4-style `Authorization`
/// value that lists `range` among its signed headers.
#[tokio::test]
async fn should_forward_caller_range_when_getrequest_range_is_none() {
    let mock_server = wiremock::MockServer::start().await;

    let partial_content = ResponseTemplate::new(206)
        .insert_header("Content-Range", "bytes 5242880-10485759/12582912")
        .insert_header("Content-Length", "5242880")
        .set_body_bytes(vec![0u8; 5_242_880]);
    Mock::given(method("GET"))
        .and(path("/object.bin"))
        .and(header("Range", "bytes=5242880-10485759"))
        .respond_with(partial_content)
        .expect(1)
        .mount(&mock_server)
        .await;

    let authorization = HeaderValue::from_static(
        "AWS4-HMAC-SHA256 Credential=AKIA/20260101/us-west-2/s3/aws4_request, \
         SignedHeaders=host;range;x-amz-content-sha256;x-amz-date, Signature=deadbeef",
    );
    let mut request_headers = HeaderMap::new();
    request_headers.insert(RANGE, HeaderValue::from_static("bytes=5242880-10485759"));
    request_headers.insert(reqwest::header::AUTHORIZATION, authorization);

    let backend = HTTP::new(HTTP_SCHEME, None, 1, true, Duration::from_secs(600), true).unwrap();
    let request = GetRequest {
        task_id: "test".to_string(),
        piece_id: "test".to_string(),
        url: format!("{}/object.bin", mock_server.uri()),
        range: None,
        http_header: Some(request_headers),
        timeout: Duration::from_secs(5),
        client_cert: None,
        object_storage: None,
        hdfs: None,
        hugging_face: None,
        model_scope: None,
    };
    let response = backend.get(request).await.unwrap();

    assert_eq!(
        response.http_status_code,
        Some(StatusCode::PARTIAL_CONTENT)
    );
}

/// With `GetRequest::range` set to `Some`, the upstream request must carry a
/// single `Range` header derived from it (`bytes=0-1023` for start 0,
/// length 1024) rather than the `Range` value supplied via `http_header`.
///
/// The responder answers 400 for any other `Range` header set, so the 206
/// assertion fails if the caller-supplied value is sent instead.
#[tokio::test]
async fn should_overwrite_caller_range_when_getrequest_range_is_some() {
    let mock_server = wiremock::MockServer::start().await;

    let range_checker = AssertRangeResponder {
        expected: "bytes=0-1023",
        body: &[0u8; 1024],
        content_range: "bytes 0-1023/12582912",
    };
    Mock::given(method("GET"))
        .and(path("/object.bin"))
        .respond_with(range_checker)
        .expect(1)
        .mount(&mock_server)
        .await;

    // Caller-supplied Range that is expected to be replaced.
    let mut request_headers = HeaderMap::new();
    request_headers.insert(RANGE, HeaderValue::from_static("bytes=5242880-10485759"));

    let backend = HTTP::new(HTTP_SCHEME, None, 1, true, Duration::from_secs(600), true).unwrap();
    let request = GetRequest {
        task_id: "test".to_string(),
        piece_id: "test".to_string(),
        url: format!("{}/object.bin", mock_server.uri()),
        range: Some(Range {
            start: 0,
            length: 1024,
        }),
        http_header: Some(request_headers),
        timeout: Duration::from_secs(5),
        client_cert: None,
        object_storage: None,
        hdfs: None,
        hugging_face: None,
        model_scope: None,
    };
    let response = backend.get(request).await.unwrap();

    assert_eq!(
        response.http_status_code,
        Some(StatusCode::PARTIAL_CONTENT)
    );
}

/// `stat` must report the total size from `Content-Range` (12582912) as the
/// content length, not the `Content-Length` of the returned slice (5242880).
#[tokio::test]
async fn should_stat_returns_content_range_total_as_content_length() {
    let mock_server = wiremock::MockServer::start().await;

    let partial_content = ResponseTemplate::new(206)
        .insert_header("Content-Range", "bytes 5242880-10485759/12582912")
        .insert_header("Content-Length", "5242880")
        .set_body_bytes(vec![0u8; 5_242_880]);
    Mock::given(method("GET"))
        .and(path("/object.bin"))
        .and(header("Range", "bytes=5242880-10485759"))
        .respond_with(partial_content)
        .expect(1)
        .mount(&mock_server)
        .await;

    let mut request_headers = HeaderMap::new();
    request_headers.insert(RANGE, HeaderValue::from_static("bytes=5242880-10485759"));

    let backend = HTTP::new(HTTP_SCHEME, None, 1, true, Duration::from_secs(600), true).unwrap();
    let request = StatRequest {
        task_id: "test".to_string(),
        url: format!("{}/object.bin", mock_server.uri()),
        http_header: Some(request_headers),
        timeout: Duration::from_secs(5),
        client_cert: None,
        object_storage: None,
        hdfs: None,
        hugging_face: None,
        model_scope: None,
    };
    let stat = backend.stat(request).await.unwrap();

    assert!(stat.success);
    assert_eq!(stat.content_length, Some(12_582_912));
}

/// When `Content-Range` reports an unknown total (`*`), `stat` must return no
/// content length instead of falling back to the slice's `Content-Length`
/// (1024).
#[tokio::test]
async fn should_stat_not_fallback_to_slice_length_when_content_range_total_is_unknown() {
    let mock_server = wiremock::MockServer::start().await;

    let unknown_total = ResponseTemplate::new(206)
        .insert_header("Content-Range", "bytes 0-1023/*")
        .insert_header("Content-Length", "1024")
        .set_body_bytes(vec![0u8; 1024]);
    Mock::given(method("GET"))
        .and(path("/object.bin"))
        .and(header("Range", "bytes=0-1023"))
        .respond_with(unknown_total)
        .expect(1)
        .mount(&mock_server)
        .await;

    let mut request_headers = HeaderMap::new();
    request_headers.insert(RANGE, HeaderValue::from_static("bytes=0-1023"));

    let backend = HTTP::new(HTTP_SCHEME, None, 1, true, Duration::from_secs(600), true).unwrap();
    let request = StatRequest {
        task_id: "test".to_string(),
        url: format!("{}/object.bin", mock_server.uri()),
        http_header: Some(request_headers),
        timeout: Duration::from_secs(5),
        client_cert: None,
        object_storage: None,
        hdfs: None,
        hugging_face: None,
        model_scope: None,
    };
    let stat = backend.stat(request).await.unwrap();

    assert!(stat.success);
    assert_eq!(stat.content_length, None);
}
}
32 changes: 32 additions & 0 deletions dragonfly-client/src/grpc/dfdaemon_download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::dynconfig::Dynconfig;
use crate::grpc::block_list::{
BlockList, DownloadBlockListCheckParams, UploadBlockListCheckParams,
};
use crate::proxy::header as proxy_header;
use crate::resource::{persistent_cache_task, persistent_task, task};
use dragonfly_api::common::v2::{
CacheTask, PersistentCacheTask, PersistentTask, Priority, Task, TaskType,
Expand Down Expand Up @@ -342,6 +343,37 @@ impl DfdaemonDownload for DfdaemonDownloadServerHandler {
// If concurrent_piece_count is not set in the request, use the default value in the config.
download.concurrent_piece_count = Some(self.config.download.concurrent_piece_count);

download.request_header.retain(|key, _| {
!key.eq_ignore_ascii_case(
proxy_header::DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER,
)
});

if let Ok(parsed_header) = hashmap_to_headermap(&download.request_header) {
if proxy_header::range_header_is_signature_bound(
&parsed_header,
Some(download.url.as_str()),
) {
download.request_header.insert(
proxy_header::DRAGONFLY_PRESERVE_ORIGINAL_RANGE_FOR_SOURCE_HEADER.to_string(),
"true".to_string(),
);

let range_value = parsed_header
.get(reqwest::header::RANGE)
.and_then(|v| v.to_str().ok())
.unwrap_or_default();
let task_id_content = download
.content_for_calculating_task_id
.take()
.unwrap_or_else(|| {
format!("{}\n{}", download.url, download.piece_length.unwrap_or(0))
});
download.content_for_calculating_task_id =
Some(format!("{}\n{}", task_id_content, range_value));
}
}

// Generate the task id.
let task_id = self
.task
Expand Down
Loading
Loading