Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ opendal = { version = "0.51.0", features = ["services-http"] }
tokio = { version = "1.41.1", features = ["rt-multi-thread"] }
zarrs_opendal = "0.5.0"
zarrs_metadata = "0.3.3" # require recent zarr-python compatibility fixes (remove with zarrs 0.20)
itertools = "0.9.0"

[profile.release]
lto = true
54 changes: 44 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
#![allow(clippy::module_name_repetitions)]

use std::borrow::Cow;
use std::collections::HashMap;
use std::ptr::NonNull;
use std::sync::Arc;

use chunk_item::WithSubset;
use itertools::Itertools;
use numpy::npyffi::PyArrayObject;
use numpy::{PyArrayDescrMethods, PyUntypedArray, PyUntypedArrayMethods};
use pyo3::exceptions::{PyRuntimeError, PyTypeError, PyValueError};
Expand All @@ -14,12 +17,16 @@ use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon_iter_concurrent_limit::iter_concurrent_limit;
use unsafe_cell_slice::UnsafeCellSlice;
use zarrs::array::codec::{ArrayToBytesCodecTraits, CodecOptions, CodecOptionsBuilder};
use utils::is_whole_chunk;
use zarrs::array::codec::{
ArrayPartialDecoderTraits, ArrayToBytesCodecTraits, CodecOptions, CodecOptionsBuilder,
};
use zarrs::array::{
copy_fill_value_into, update_array_bytes, ArrayBytes, ArraySize, CodecChain, FillValue,
};
use zarrs::array_subset::ArraySubset;
use zarrs::metadata::v3::MetadataV3;
use zarrs::storage::StoreKey;

mod chunk_item;
mod concurrency;
Expand Down Expand Up @@ -265,15 +272,44 @@ impl CodecPipelineImpl {
return Ok(());
};

// Assemble partial decoders ahead of time and in parallel
let partial_chunk_descriptions = chunk_descriptions
.iter()
.filter(|item| !(is_whole_chunk(item)))
.unique_by(|item| item.key())
.collect::<Vec<_>>();
let mut partial_decoder_cache: HashMap<StoreKey, Arc<dyn ArrayPartialDecoderTraits>> =
HashMap::new().into();
if partial_chunk_descriptions.len() > 0 {
let key_decoder_pairs = iter_concurrent_limit!(
chunk_concurrent_limit,
partial_chunk_descriptions,
map,
|item| {
let input_handle = self.stores.decoder(item)?;
let partial_decoder = self
.codec_chain
.clone()
.partial_decoder(
Arc::new(input_handle),
item.representation(),
&codec_options,
)
.map_py_err::<PyValueError>()?;
Ok((item.key().clone(), partial_decoder))
}
)
.collect::<PyResult<Vec<_>>>()?;
partial_decoder_cache.extend(key_decoder_pairs);
}

py.allow_threads(move || {
// FIXME: the `decode_into` methods only support fixed length data types.
// For variable length data types, need a codepath with non `_into` methods.
// Collect all the subsets and copy into value on the Python side?
let update_chunk_subset = |item: chunk_item::WithSubset| {
// See zarrs::array::Array::retrieve_chunk_subset_into
if item.chunk_subset.start().iter().all(|&o| o == 0)
&& item.chunk_subset.shape() == item.representation().shape_u64()
{
if is_whole_chunk(&item) {
// See zarrs::array::Array::retrieve_chunk_into
if let Some(chunk_encoded) = self.stores.get(&item)? {
// Decode the encoded data into the output buffer
Expand Down Expand Up @@ -308,12 +344,10 @@ impl CodecPipelineImpl {
}
}
} else {
let input_handle = Arc::new(self.stores.decoder(&item)?);
let partial_decoder = self
.codec_chain
.clone()
.partial_decoder(input_handle, item.representation(), &codec_options)
.map_py_err::<PyValueError>()?;
let key = item.key();
let partial_decoder = partial_decoder_cache.get(key).ok_or_else(|| {
PyRuntimeError::new_err(format!("Partial decoder not found for key: {key}"))
})?;
unsafe {
// SAFETY:
// - output is an array with output_shape elements of the item.representation data type,
Expand Down
7 changes: 7 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::fmt::Display;
use numpy::{PyUntypedArray, PyUntypedArrayMethods};
use pyo3::{Bound, PyErr, PyResult, PyTypeInfo};

use crate::{ChunksItem, WithSubset};

pub(crate) trait PyErrExt<T> {
fn map_py_err<PE: PyTypeInfo>(self) -> PyResult<T>;
}
Expand All @@ -29,3 +31,8 @@ impl PyUntypedArrayExt for Bound<'_, PyUntypedArray> {
})
}
}

pub fn is_whole_chunk(item: &WithSubset) -> bool {
item.chunk_subset.start().iter().all(|&o| o == 0)
&& item.chunk_subset.shape() == item.representation().shape_u64()
}
Loading