Skip to content

Commit d9dfd2d

Browse files
committed
Refactor to allow reinstating resource manager
1 parent f809321 commit d9dfd2d

File tree

2 files changed

+19
-58
lines changed

2 files changed

+19
-58
lines changed

src/app.rs

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ use bytes::Bytes;
2525
use cached::{proc_macro::io_cached, stores::DiskCacheBuilder};
2626

2727
use std::sync::Arc;
28-
use tokio::sync::SemaphorePermit;
2928
use tower::Layer;
3029
use tower::ServiceBuilder;
3130
use tower_http::normalize_path::NormalizePathLayer;
@@ -168,11 +167,7 @@ async fn schema() -> &'static str {
168167
///
169168
/// * `client`: S3 client object
170169
/// * `request_data`: RequestData object for the request
171-
#[tracing::instrument(
172-
level = "DEBUG",
173-
// skip(client, request_data, resource_manager, mem_permits)
174-
skip(client, request_data)
175-
)]
170+
#[tracing::instrument(level = "DEBUG", skip(client, request_data, resource_manager))]
176171
#[io_cached(
177172
map_error = r##"|e| ActiveStorageError::CacheError{ error: format!("{:?}", e) }"##,
178173
disk = true,
@@ -183,11 +178,15 @@ async fn schema() -> &'static str {
183178
async fn download_object<'a>(
184179
client: &s3_client::S3Client,
185180
request_data: &models::RequestData,
186-
// resource_manager: &'a ResourceManager,
187-
// mem_permits: &mut Option<SemaphorePermit<'a>>,
181+
resource_manager: &ResourceManager,
188182
) -> Result<Bytes, ActiveStorageError> {
183+
// If we're given a size in the request data then use this to
184+
// get an initial guess at the required memory resources.
185+
let memory = request_data.size.unwrap_or(0);
186+
let mut mem_permits = resource_manager.memory(memory).await?;
187+
189188
let range = s3_client::get_range(request_data.offset, request_data.size);
190-
// let _conn_permits = resource_manager.s3_connection().await?;
189+
let _conn_permits = resource_manager.s3_connection().await?;
191190

192191
// Increment the prometheus metric for cache misses
193192
LOCAL_CACHE_MISSES.with_label_values(&["disk"]).inc();
@@ -197,8 +196,8 @@ async fn download_object<'a>(
197196
&request_data.bucket,
198197
&request_data.object,
199198
range,
200-
// resource_manager,
201-
// mem_permits,
199+
resource_manager,
200+
&mut mem_permits,
202201
)
203202
.await
204203
}
@@ -222,8 +221,6 @@ async fn operation_handler<T: operation::Operation>(
222221
auth: Option<TypedHeader<Authorization<Basic>>>,
223222
ValidatedJson(request_data): ValidatedJson<models::RequestData>,
224223
) -> Result<models::Response, ActiveStorageError> {
225-
let memory = request_data.size.unwrap_or(0);
226-
let mut _mem_permits = state.resource_manager.memory(memory).await?;
227224
let credentials = if let Some(TypedHeader(auth)) = auth {
228225
s3_client::S3Credentials::access_key(auth.username(), auth.password())
229226
} else {
@@ -234,14 +231,9 @@ async fn operation_handler<T: operation::Operation>(
234231
.get(&request_data.source, credentials)
235232
.instrument(tracing::Span::current())
236233
.await;
237-
let data = download_object(
238-
&s3_client,
239-
&request_data,
240-
// &state.resource_manager,
241-
// &mut _mem_permits,
242-
)
243-
.instrument(tracing::Span::current())
244-
.await?;
234+
let data = download_object(&s3_client, &request_data, &state.resource_manager)
235+
.instrument(tracing::Span::current())
236+
.await?;
245237
// All remaining work is synchronous. If the use_rayon argument was specified, delegate to the
246238
// Rayon thread pool. Otherwise, execute as normal using Tokio.
247239
if state.args.use_rayon {

src/s3_client.rs

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ impl S3Client {
156156
bucket: &str,
157157
key: &str,
158158
range: Option<String>,
159-
// resource_manager: &'a ResourceManager,
160-
// mem_permits: &mut Option<SemaphorePermit<'a>>,
159+
resource_manager: &'a ResourceManager,
160+
mem_permits: &mut Option<SemaphorePermit<'a>>,
161161
) -> Result<Bytes, ActiveStorageError> {
162162
let mut response = self
163163
.client
@@ -175,9 +175,10 @@ impl S3Client {
175175
.try_into()?;
176176

177177
// FIXME: how to account for compressed data?
178-
// if mem_permits.is_none() {
179-
// *mem_permits = resource_manager.memory(content_length).await?;
180-
// };
178+
if mem_permits.is_none() || mem_permits.as_ref().unwrap().num_permits() == 0 {
179+
*mem_permits = resource_manager.memory(content_length).await?;
180+
};
181+
181182
// The data returned by the S3 client does not have any alignment guarantees. In order to
182183
// reinterpret the data as an array of numbers with a higher alignment than 1, we need to
183184
// return the data in Bytes object in which the underlying data has a higher alignment.
@@ -225,40 +226,8 @@ pub fn get_range(offset: Option<usize>, size: Option<usize>) -> Option<String> {
225226
#[cfg(test)]
226227
mod tests {
227228
use super::*;
228-
use cached::{proc_macro::io_cached, stores::DiskCacheBuilder};
229229
use url::Url;
230230

231-
// #[cached(
232-
// ty = "SizedCache<String, String>",
233-
// create = "{ SizedCache::with_size(100) }",
234-
// convert = r#"{ format!("{}{}", a, b) }"#
235-
// )]
236-
// fn cache_test(a: &str, b: &str) -> String {
237-
// format!("{} - {}", a, b)
238-
// }
239-
240-
#[io_cached(
241-
map_error = r##"|e| ActiveStorageError::CacheError{ error: format!("{:?}", e) }"##,
242-
disk = true,
243-
create = r##"{ DiskCacheBuilder::new("test-cache").set_disk_directory("./").build().expect("valid disk cache builder") }"##,
244-
key = "String",
245-
convert = r##"{ format!("{}:{}", a, b) }"##
246-
)]
247-
async fn cache_test(a: &str, b: &str) -> Result<String, ActiveStorageError> {
248-
println!("Function called");
249-
Ok(format!("{} - {}", a, b))
250-
}
251-
252-
#[tokio::test]
253-
async fn disk_cache() {
254-
// cache_test("a").unwrap();
255-
// cache_test("a").unwrap();
256-
// cache_test(1, 2).unwrap();
257-
// cache_test(1, 2).unwrap();
258-
cache_test("a", "b").await.unwrap();
259-
cache_test("a", "b").await.unwrap();
260-
}
261-
262231
fn make_access_key() -> S3Credentials {
263232
S3Credentials::access_key("user", "password")
264233
}

0 commit comments

Comments
 (0)