diff --git a/src/awkward/_nplikes/array_module.py b/src/awkward/_nplikes/array_module.py index 38ca0c4fbb..ae97f2a8b0 100644 --- a/src/awkward/_nplikes/array_module.py +++ b/src/awkward/_nplikes/array_module.py @@ -80,6 +80,7 @@ def asarray( obj._dtype if dtype is None else dtype, lambda: self.asarray(obj.materialize(), dtype=dtype, copy=copy), lambda: obj.shape, + __enable_caching__=obj.__enable_caching__, ) if copy: return self._module.array(obj, dtype=dtype, copy=True) @@ -108,6 +109,7 @@ def ascontiguousarray( x._dtype, lambda: self.ascontiguousarray(x.materialize()), # type: ignore[arg-type] lambda: x.shape, + __enable_caching__=x.__enable_caching__, ) else: return self._module.ascontiguousarray(x) @@ -356,6 +358,7 @@ def reshape( x.dtype, lambda: self.reshape(x.materialize(), next_shape, copy=copy), # type: ignore[arg-type] None, + __enable_caching__=x.__enable_caching__, ) if copy is None: diff --git a/src/awkward/_nplikes/virtual.py b/src/awkward/_nplikes/virtual.py index 7ddb3d7d83..53a2392f76 100644 --- a/src/awkward/_nplikes/virtual.py +++ b/src/awkward/_nplikes/virtual.py @@ -81,6 +81,7 @@ def __init__( generator: Callable[[], ArrayLike], shape_generator: Callable[[], tuple[ShapeItem, ...]] | None = None, __wrap_generator_asarray__: bool = False, + __enable_caching__: bool = True, ) -> None: if not nplike.supports_virtual_arrays: raise TypeError( @@ -104,6 +105,8 @@ def __init__( self._generator = generator self._shape_generator = shape_generator + self.__enable_caching__ = __enable_caching__ + @property def dtype(self) -> DType: return self._dtype @@ -180,9 +183,11 @@ def materialize(self) -> ArrayLike: f"{type(self).__name__} had dtype {self._dtype} before materialization while the materialized array has dtype {array.dtype}" ) self._shape = array.shape - self._array = array - self._shape_generator = assert_never - self._generator = assert_never + if self.__enable_caching__: + self._array = array + self._shape_generator = assert_never + self._generator = assert_never + return array return self._array # type: ignore[return-value] @property @@ -205,6 +210,7 @@ def T(self): self._dtype, lambda: self.materialize().T, lambda: self.shape[::-1], + __enable_caching__=self.__enable_caching__, ) def view(self, dtype: DTypeLike) -> Self: @@ -237,6 +243,7 @@ def view(self, dtype: DTypeLike) -> Self: dtype, lambda: self.materialize().view(dtype), None, + __enable_caching__=self.__enable_caching__, ) @property @@ -263,6 +270,7 @@ def byteswap(self, inplace=False): self._dtype, lambda: self.materialize().byteswap(inplace=inplace), lambda: self.shape, + __enable_caching__=self.__enable_caching__, ) def tobytes(self, order="C") -> bytes: @@ -275,6 +283,7 @@ def __copy__(self) -> VirtualNDArray: self._dtype, self._generator, self._shape_generator, + __enable_caching__=self.__enable_caching__, ) new_virtual._array = self._array return new_virtual @@ -287,6 +296,7 @@ def __deepcopy__(self, memo) -> VirtualNDArray: self._dtype, lambda: copy.deepcopy(current_generator(), memo), self._shape_generator, + __enable_caching__=self.__enable_caching__, ) new_virtual._array = ( copy.deepcopy(self._array, memo) @@ -340,6 +350,7 @@ def __getitem__(self, index): self._dtype, lambda: self.materialize()[index], None, + __enable_caching__=self.__enable_caching__, ) else: return self.materialize().__getitem__(index) diff --git a/src/awkward/_pickle.py b/src/awkward/_pickle.py index 41449f4d28..d47c4b930a 100644 --- a/src/awkward/_pickle.py +++ b/src/awkward/_pickle.py @@ -115,6 +115,7 @@ def unpickle_array_schema_1( buffer_key="{form_key}-{attribute}", byteorder="<", simplify=False, + enable_virtualarray_caching=True, ) @@ -141,6 +142,7 @@ def unpickle_record_schema_1( buffer_key="{form_key}-{attribute}", byteorder="<", simplify=False, + enable_virtualarray_caching=True, ) layout = LowLevelRecord(array_layout, at) return Record(layout, behavior=behavior, attrs=attrs) diff --git a/src/awkward/forms/form.py b/src/awkward/forms/form.py index e0d3babb0f..bc7d76bd5c 100644 --- a/src/awkward/forms/form.py +++ b/src/awkward/forms/form.py @@ -540,6 +540,7 @@ def length_zero_array( behavior=behavior, attrs=None, simplify=False, + enable_virtualarray_caching=True, ) def length_one_array( @@ -696,6 +697,7 @@ def prepare(form, multiplier): behavior=behavior, attrs=None, simplify=False, + enable_virtualarray_caching=True, ) def _expected_from_buffers( diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index f87643993e..dafab10d41 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -2953,6 +2953,7 @@ def snapshot(self): backend="cpu", byteorder=ak._util.native_byteorder, simplify=True, + enable_virtualarray_caching=True, highlevel=True, behavior=self._behavior, attrs=self._attrs, diff --git a/src/awkward/operations/ak_from_avro_file.py b/src/awkward/operations/ak_from_avro_file.py index dcd49d8f6c..fd08c50dce 100644 --- a/src/awkward/operations/ak_from_avro_file.py +++ b/src/awkward/operations/ak_from_avro_file.py @@ -73,5 +73,6 @@ def _impl(form, length, container, highlevel, behavior, attrs): highlevel=highlevel, behavior=behavior, simplify=True, + enable_virtualarray_caching=True, attrs=attrs, ) diff --git a/src/awkward/operations/ak_from_buffers.py b/src/awkward/operations/ak_from_buffers.py index 98e7dccf7b..570d6db5fc 100644 --- a/src/awkward/operations/ak_from_buffers.py +++ b/src/awkward/operations/ak_from_buffers.py @@ -36,6 +36,7 @@ def from_buffers( backend="cpu", byteorder="<", allow_noncanonical_form=False, + enable_virtualarray_caching=True, highlevel=True, behavior=None, attrs=None, @@ -63,6 +64,14 @@ def from_buffers( allow_noncanonical_form (bool): If True, non-canonical forms will be simplified to produce arrays with canonical layouts; otherwise, an exception will be thrown for such forms. + enable_virtualarray_caching (bool or callable): If True (the default), + all VirtualNDArray buffers that get created will cache their + materialized buffers on themselves when they get materialized. + If a callable is given, it must accept two arguments, `form_key` and `attribute`, + and return a boolean indicating whether caching should be enabled for the + given buffer. The `form_key` and `attribute` are the same as those passed to + the `buffer_key` function. If False, all VirtualNDArrays will not cache + their materialized buffers on themselves. highlevel (bool): If True, return an #ak.Array; otherwise, return a low-level #ak.contents.Content subclass. behavior (None or dict): Custom #ak.behavior for the output array, if @@ -118,6 +127,7 @@ def from_buffers( behavior, attrs, allow_noncanonical_form, + enable_virtualarray_caching, ) @@ -132,6 +142,7 @@ def _impl( behavior, attrs, simplify, + enable_virtualarray_caching, ): backend = regularize_backend(backend) @@ -152,6 +163,17 @@ def _impl( "'form' argument must be a Form or its Python dict/JSON string representation" ) + if isinstance(enable_virtualarray_caching, bool): + + def enable_caching_function(form_key, attribute): + return enable_virtualarray_caching + elif callable(enable_virtualarray_caching): + enable_caching_function = enable_virtualarray_caching + else: + raise TypeError( + "'enable_virtualarray_caching' argument must be a boolean or a callable" + ) + getkey = regularize_buffer_key(buffer_key) out = _reconstitute( @@ -164,6 +186,7 @@ def _impl( simplify, field_path=(), shape_generator=lambda: (length,), + enable_virtualarray_caching=enable_caching_function, ) return wrap_layout(out, highlevel=highlevel, attrs=attrs, behavior=behavior) @@ -177,6 +200,7 @@ def _from_buffer( byteorder: str, field_path: tuple, shape_generator: Callable | None = None, + enable_virtualarray_caching: bool = False, ) -> ArrayLike: if isinstance(buffer, VirtualNDArray): # This is the case for VirtualNDArrays @@ -208,7 +232,14 @@ def _from_buffer( def generator(): (length,) = cached_shape_generator() return _from_buffer( - nplike, buffer(), dtype, length, byteorder, field_path, None + nplike, + buffer(), + dtype, + length, + byteorder, + field_path, + None, + False, ) # also store a ref to the original/raw buffer generator @@ -222,6 +253,7 @@ def generator(): generator=generator, shape_generator=cached_shape_generator, __wrap_generator_asarray__=True, + __enable_caching__=enable_virtualarray_caching, ) # Unknown-length information implies that we didn't load shape-buffers (offsets, etc) # for the parent of this node. Thus, this node and its children *must* only @@ -264,6 +296,7 @@ def _reconstitute( simplify, field_path, shape_generator, + enable_virtualarray_caching, ): if isinstance(form, ak.forms.EmptyForm): if length is not unknown_length and length != 0: @@ -291,6 +324,9 @@ def _shape_generator(): byteorder=byteorder, field_path=field_path, shape_generator=_shape_generator, + enable_virtualarray_caching=enable_virtualarray_caching( + form.form_key, "data" + ), ) if form.inner_shape != (): data = backend.nplike.reshape(data, (length, *form.inner_shape)) @@ -310,6 +346,7 @@ def _shape_generator(): simplify, field_path, shape_generator, + enable_virtualarray_caching, ) if simplify: make = ak.contents.UnmaskedArray.simplified @@ -340,6 +377,9 @@ def _shape_generator(): byteorder=byteorder, field_path=field_path, shape_generator=_shape_generator, + enable_virtualarray_caching=enable_virtualarray_caching( + form.form_key, "mask" + ), ) content = _reconstitute( form.content, @@ -351,6 +391,7 @@ def _shape_generator(): simplify, field_path, shape_generator, + enable_virtualarray_caching, ) if simplify: make = ak.contents.BitMaskedArray.simplified @@ -378,6 +419,9 @@ def _shape_generator(): byteorder=byteorder, field_path=field_path, shape_generator=shape_generator, + enable_virtualarray_caching=enable_virtualarray_caching( + form.form_key, "mask" + ), ) content = _reconstitute( form.content, @@ -389,6 +433,7 @@ def _shape_generator(): simplify, field_path, shape_generator, + enable_virtualarray_caching, ) if simplify: make = ak.contents.ByteMaskedArray.simplified @@ -411,6 +456,9 @@ def _shape_generator(): byteorder=byteorder, field_path=field_path, shape_generator=shape_generator, + enable_virtualarray_caching=enable_virtualarray_caching( + form.form_key, "index" + ), ) def _adjust_length(index): @@ -433,6 +481,7 @@ def _shape_generator(): simplify, field_path, _shape_generator, + enable_virtualarray_caching, ) if simplify: make = ak.contents.IndexedOptionArray.simplified @@ -454,6 +503,9 @@ def _shape_generator(): byteorder=byteorder, field_path=field_path, shape_generator=shape_generator, + enable_virtualarray_caching=enable_virtualarray_caching( + form.form_key, "index" + ), ) def _adjust_length(index): @@ -480,6 +532,7 @@ def _shape_generator(): simplify, field_path, _shape_generator, + enable_virtualarray_caching, ) if simplify: make = ak.contents.IndexedArray.simplified @@ -502,6 +555,9 @@ def _shape_generator(): byteorder=byteorder, field_path=field_path, shape_generator=shape_generator, + enable_virtualarray_caching=enable_virtualarray_caching( + form.form_key, "starts" + ), ) stops = _from_buffer( backend.nplike, @@ -511,6 +567,9 @@ def _shape_generator(): byteorder=byteorder, field_path=field_path, shape_generator=shape_generator, + enable_virtualarray_caching=enable_virtualarray_caching( + form.form_key, "stops" + ), ) def _adjust_length(starts, stops): @@ -536,6 +595,7 @@ def _shape_generator(): simplify, field_path, _shape_generator, + enable_virtualarray_caching, ) return ak.contents.ListArray( ak.index.Index(starts), @@ -559,6 +619,9 @@ def _shape_generator(): byteorder=byteorder, field_path=field_path, shape_generator=_shape_generator, + enable_virtualarray_caching=enable_virtualarray_caching( + form.form_key, "offsets" + ), ) # next length @@ -583,6 +646,7 @@ def _shape_generator(): simplify, field_path, _shape_generator, + enable_virtualarray_caching, ) return ak.contents.ListOffsetArray( ak.index.Index(offsets), @@ -611,6 +675,7 @@ def _shape_generator(): simplify, field_path, _shape_generator, + enable_virtualarray_caching, ) return ak.contents.RegularArray( content, @@ -631,6 +696,7 @@ def _shape_generator(): simplify, (*field_path, field), shape_generator, + enable_virtualarray_caching, ) for content, field in zip(form.contents, form.fields) ] @@ -653,6 +719,9 @@ def _shape_generator(): byteorder=byteorder, field_path=field_path, shape_generator=shape_generator, + enable_virtualarray_caching=enable_virtualarray_caching( + form.form_key, "tags" + ), ) index = _from_buffer( backend.nplike, @@ -662,6 +731,9 @@ def _shape_generator(): byteorder=byteorder, field_path=field_path, shape_generator=shape_generator, + enable_virtualarray_caching=enable_virtualarray_caching( + form.form_key, "index" + ), ) def _adjust_length(index, tags, tag): @@ -699,6 +771,7 @@ def _shape_generator(tag): simplify, field_path, _shape_generators[i], + enable_virtualarray_caching, ) for i, content in enumerate(form.contents) ] diff --git a/src/awkward/operations/ak_from_iter.py b/src/awkward/operations/ak_from_iter.py index 7f06e8cc14..e038fdad7e 100644 --- a/src/awkward/operations/ak_from_iter.py +++ b/src/awkward/operations/ak_from_iter.py @@ -112,5 +112,6 @@ def _impl(iterable, highlevel, behavior, allow_record, initial, resize, attrs): highlevel=highlevel, behavior=behavior, simplify=True, + enable_virtualarray_caching=True, attrs=attrs, )[0] diff --git a/src/awkward/operations/ak_from_safetensors.py b/src/awkward/operations/ak_from_safetensors.py index 3eff1a5daf..f161fe8e06 100644 --- a/src/awkward/operations/ak_from_safetensors.py +++ b/src/awkward/operations/ak_from_safetensors.py @@ -145,6 +145,7 @@ def maybe_virtualize(x): backend=backend, byteorder=byteorder, simplify=allow_noncanonical_form, + enable_virtualarray_caching=True, highlevel=highlevel, behavior=behavior, attrs=attrs, diff --git a/tests/test_3741_virtualarray_caching.py b/tests/test_3741_virtualarray_caching.py new file mode 100644 index 0000000000..d5831bc693 --- /dev/null +++ b/tests/test_3741_virtualarray_caching.py @@ -0,0 +1,87 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import numpy as np +import pytest + +import awkward as ak +from awkward._nplikes.shape import unknown_length + + +@pytest.mark.parametrize("offsets_length", [5, unknown_length]) +@pytest.mark.parametrize("content_length", [9, unknown_length]) +def test(offsets_length, content_length): + offset_generator = lambda: np.array([0, 2, 4, 5, 6], dtype=np.int64) # noqa: E731 + data_generator = lambda: np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) # noqa: E731 + buffers = {"node0-offsets": offset_generator, "node1-data": data_generator} + form = ak.forms.ListOffsetForm( + "i64", ak.forms.NumpyForm("int64", form_key="node1"), form_key="node0" + ) + + array = ak.from_buffers(form, 4, buffers, enable_virtualarray_caching=True) + assert array.to_list() == [[1, 2], [3, 4], [5], [6]] + assert array.layout.is_all_materialized + assert ak.materialize(array).to_list() == [[1, 2], [3, 4], [5], [6]] + + array = ak.from_buffers(form, 4, buffers, enable_virtualarray_caching=False) + assert array.to_list() == [[1, 2], [3, 4], [5], [6]] + assert not array.layout.is_any_materialized + assert ak.materialize(array).to_list() == [[1, 2], [3, 4], [5], [6]] + + array = ak.from_buffers( + form, 4, buffers, enable_virtualarray_caching=lambda form_key, attribute: True + ) + assert array.to_list() == [[1, 2], [3, 4], [5], [6]] + assert array.layout.is_all_materialized + assert ak.materialize(array).to_list() == [[1, 2], [3, 4], [5], [6]] + + array = ak.from_buffers( + form, 4, buffers, enable_virtualarray_caching=lambda form_key, attribute: False + ) + assert array.to_list() == [[1, 2], [3, 4], [5], [6]] + assert not array.layout.is_any_materialized + assert ak.materialize(array).to_list() == [[1, 2], [3, 4], [5], [6]] + + array = ak.from_buffers( + form, + 4, + buffers, + enable_virtualarray_caching=lambda form_key, attribute: attribute != "data", + ) + assert array.to_list() == [[1, 2], [3, 4], [5], [6]] + assert array.layout.offsets.is_all_materialized + assert not array.layout.content.is_any_materialized + assert ak.materialize(array).to_list() == [[1, 2], [3, 4], [5], [6]] + + array = ak.from_buffers( + form, + 4, + buffers, + enable_virtualarray_caching=lambda form_key, attribute: attribute == "data", + ) + assert array.to_list() == [[1, 2], [3, 4], [5], [6]] + assert not array.layout.offsets.is_any_materialized + assert array.layout.content.is_all_materialized + assert ak.materialize(array).to_list() == [[1, 2], [3, 4], [5], [6]] + + array = ak.from_buffers( + form, + 4, + buffers, + enable_virtualarray_caching=lambda form_key, attribute: attribute != "offsets", + ) + assert array.to_list() == [[1, 2], [3, 4], [5], [6]] + assert not array.layout.offsets.is_any_materialized + assert array.layout.content.is_all_materialized + assert ak.materialize(array).to_list() == [[1, 2], [3, 4], [5], [6]] + + array = ak.from_buffers( + form, + 4, + buffers, + enable_virtualarray_caching=lambda form_key, attribute: attribute == "offsets", + ) + assert array.to_list() == [[1, 2], [3, 4], [5], [6]] + assert array.layout.offsets.is_all_materialized + assert ak.materialize(array).to_list() == [[1, 2], [3, 4], [5], [6]]