Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ opendal = { version = "0.51.0", features = ["services-http"] }
tokio = { version = "1.41.1", features = ["rt-multi-thread"] }
zarrs_opendal = "0.5.0"
zarrs_metadata = "0.3.3" # require recent zarr-python compatibility fixes (remove with zarrs 0.20)
itertools = "0.9.0"

[profile.release]
lto = true
57 changes: 46 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
#![allow(clippy::module_name_repetitions)]

use std::borrow::Cow;
use std::collections::HashMap;
use std::ptr::NonNull;
use std::sync::Arc;

use chunk_item::WithSubset;
use itertools::Itertools;
use numpy::npyffi::PyArrayObject;
use numpy::{PyArrayDescrMethods, PyUntypedArray, PyUntypedArrayMethods};
use pyo3::exceptions::{PyRuntimeError, PyTypeError, PyValueError};
Expand All @@ -14,12 +17,16 @@ use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon_iter_concurrent_limit::iter_concurrent_limit;
use unsafe_cell_slice::UnsafeCellSlice;
use zarrs::array::codec::{ArrayToBytesCodecTraits, CodecOptions, CodecOptionsBuilder};
use utils::is_whole_chunk;
use zarrs::array::codec::{
ArrayPartialDecoderTraits, ArrayToBytesCodecTraits, CodecOptions, CodecOptionsBuilder,
};
use zarrs::array::{
copy_fill_value_into, update_array_bytes, ArrayBytes, ArraySize, CodecChain, FillValue,
};
use zarrs::array_subset::ArraySubset;
use zarrs::metadata::v3::MetadataV3;
use zarrs::storage::StoreKey;

mod chunk_item;
mod concurrency;
Expand Down Expand Up @@ -265,15 +272,41 @@ impl CodecPipelineImpl {
return Ok(());
};

// Assemble partial decoders ahead of time and in parallel
let partial_chunk_descriptions = chunk_descriptions
.iter()
.filter(|item| !(is_whole_chunk(item)))
.unique_by(|item| item.key())
.collect::<Vec<_>>();
let mut partial_decoder_cache: HashMap<StoreKey, Arc<dyn ArrayPartialDecoderTraits>> =
HashMap::new().into();
if partial_chunk_descriptions.len() > 0 {
let key_decoder_pairs = partial_chunk_descriptions
.into_par_iter()
.map(|item| {
let input_handle = self.stores.decoder(item)?;
let partial_decoder = self
.codec_chain
.clone()
.partial_decoder(
Arc::new(input_handle),
item.representation(),
&codec_options,
)
.map_py_err::<PyValueError>()?;
Ok((item.key().clone(), partial_decoder))
})
.collect::<PyResult<Vec<_>>>()?;
partial_decoder_cache.extend(key_decoder_pairs);
}

py.allow_threads(move || {
// FIXME: the `decode_into` methods only support fixed length data types.
// For variable length data types, need a codepath with non `_into` methods.
// Collect all the subsets and copy into value on the Python side?
let update_chunk_subset = |item: chunk_item::WithSubset| {
// See zarrs::array::Array::retrieve_chunk_subset_into
if item.chunk_subset.start().iter().all(|&o| o == 0)
&& item.chunk_subset.shape() == item.representation().shape_u64()
{
if is_whole_chunk(&item) {
// See zarrs::array::Array::retrieve_chunk_into
if let Some(chunk_encoded) = self.stores.get(&item)? {
// Decode the encoded data into the output buffer
Expand Down Expand Up @@ -308,18 +341,20 @@ impl CodecPipelineImpl {
}
}
} else {
let input_handle = Arc::new(self.stores.decoder(&item)?);
let partial_decoder = self
.codec_chain
.clone()
.partial_decoder(input_handle, item.representation(), &codec_options)
.map_py_err::<PyValueError>()?;
let key = item.key();
let partial_decoder: PyResult<&Arc<dyn ArrayPartialDecoderTraits>> =
match partial_decoder_cache.get(key) {
Some(e) => Ok(e),
None => Err(PyRuntimeError::new_err(format!(
"Partial decoder not found for key: {key}"
))),
};
unsafe {
// SAFETY:
// - output is an array with output_shape elements of the item.representation data type,
// - item.subset is within the bounds of output_shape.
// - item.chunk_subset has the same number of elements as item.subset.
partial_decoder.partial_decode_into(
partial_decoder?.partial_decode_into(
&item.chunk_subset,
&output,
&output_shape,
Expand Down
7 changes: 7 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::fmt::Display;
use numpy::{PyUntypedArray, PyUntypedArrayMethods};
use pyo3::{Bound, PyErr, PyResult, PyTypeInfo};

use crate::{ChunksItem, WithSubset};

pub(crate) trait PyErrExt<T> {
fn map_py_err<PE: PyTypeInfo>(self) -> PyResult<T>;
}
Expand All @@ -29,3 +31,8 @@ impl PyUntypedArrayExt for Bound<'_, PyUntypedArray> {
})
}
}

/// Reports whether `item` selects an entire chunk.
///
/// Returns `true` only when every start coordinate of `item.chunk_subset` is
/// zero and the subset's shape equals the shape reported by
/// `item.representation()`; otherwise returns `false`.
pub fn is_whole_chunk(item: &WithSubset) -> bool {
    let subset = &item.chunk_subset;
    // A whole-chunk selection must begin at the origin in every dimension.
    let starts_at_origin = subset.start().iter().all(|&offset| offset == 0);
    // `&&` short-circuits, so the shape is only compared once the origin check passes.
    starts_at_origin && subset.shape() == item.representation().shape_u64()
}
Loading