
Memory profiling and indexing arrays with shards #3641

@ljstrnadiii

Description

I noticed extremely high memory usage when reading a small subset from a sharded array, so I reproduced it under memray and wanted to share the results.

Notes:

  • array := shape=(1, 10, 8192 * 3, 8192 * 3), chunks=(1, 1, 4096, 4096), shards=(1, 1, 8192, 8192), dtype="float32"
  • Sometimes I need to slice out a subset along a dimension, e.g. array[:, [0, 3, 6], ...] (the chunk size along that dimension may be 1...N)
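For context, a back-of-the-envelope check of the sizes involved (plain arithmetic derived from the array spec above; the variable names are mine):

```python
# Sizes for the selection arr[0, :, 8191:16384, 8191:16384] used in all three cases.
dtype_bytes = 4                      # float32
shard_spatial = 8192                 # shard edge length along each spatial dim
sel = slice(8191, 16384)             # spatial selection

sel_len = sel.stop - sel.start       # 8193 elements per spatial dim
subset_bytes = 10 * sel_len * sel_len * dtype_bytes  # dim 1 has 10 indices
print(subset_bytes / 2**30)          # ~2.5 GiB, matching the case titles

shard_bytes = shard_spatial * shard_spatial * dtype_bytes  # 256 MiB per shard

# Shards intersected along one spatial dim: 8191:16384 straddles shards 0 and 1.
first = sel.start // shard_spatial
last = (sel.stop - 1) // shard_spatial
shards_per_dim = last - first + 1    # 2

shards_touched = 10 * shards_per_dim * shards_per_dim  # 40 shards in total
print(shards_touched * shard_bytes / 2**30)  # 10 GiB if all were decoded at once
```

So the requested data is ~2.5 GiB, while fully decoding every shard the selection touches would be 10 GiB, which helps calibrate the peak numbers below.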

Case 1: Reading a 2.5 GB subset by iterating serially over dim=1

Duration: 0:00:27.586000
Total number of allocations: 3378254
Total number of frames seen: 583
Peak memory usage: 1.6 GB
(memray flamegraph)

Case 2: Reading a 2.5 GB subset using ":" over dim=1

Duration: 0:00:11.887000
Total number of allocations: 4115848
Total number of frames seen: 600
Peak memory usage: 5.7 GB
(memray flamegraph)

Case 3: Reading a 2.5 GB subset using an integer index array [0...9] (all indices) over dim=1

Duration: 0:01:33.994000
Total number of allocations: 4179800
Total number of frames seen: 814
Peak memory usage: 23.9 GB
(memray flamegraph)

Versions

python: 3.11
zarr-python: 3.1.3
obstore: 0.8.2
numpy: 2.3.3
memray: 1.18.0

Reproducibility

Please don't mind the Flyte code; it's just a helper for running this remotely.

import subprocess

import memray
import numpy as np
import zarr
from boto3 import Session
from flytekit import FlyteFile, Resources, task
from obstore.auth.boto3 import Boto3CredentialProvider
from obstore.store import S3Store
from yarl import URL
from zarr.storage import ObjectStore


def flamegraph(bin_path: str, html_path: str) -> None:
    subprocess.run(
        ["python", "-m", "memray", "flamegraph", "-o", html_path, bin_path],
        check=True,
    )


@task(resources=Resources(cpu=("8", "8"), mem=("32Gi", "32Gi")))
def profile_zarr_indexing_task(path: str) -> list[FlyteFile]:
    url = URL(path)
    store = ObjectStore(
        S3Store(
            bucket=url.host,
            prefix=url.path.lstrip("/"),
            credential_provider=Boto3CredentialProvider(Session(region_name="us-west-2")),
            config={"region": "us-west-2"},
        ),
    )

    arr = zarr.create_array(
        store=store,
        shape=(1, 10, 8192 * 3, 8192 * 3),
        chunks=(1, 1, 4096, 4096),
        shards=(1, 1, 8192, 8192),
        dtype="float32",
        overwrite=True,
    )

    rng = np.random.default_rng(seed=42)
    arr[:] = rng.random(arr.shape, dtype="float32")

    arr = zarr.open_array(store, mode="r")

    # Case 3
    with (
        memray.Tracker("fancy-indexing.bin", trace_python_allocators=True),
        zarr.config.set({"async.concurrency": 8, "threading.max_workers": 8}),
    ):
        _ = arr[0, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 8191:16384, 8191:16384]

    # Case 2
    with (
        memray.Tracker("colon-slice-indexing.bin", trace_python_allocators=True),
        zarr.config.set({"async.concurrency": 8, "threading.max_workers": 8}),
    ):
        _ = arr[0, :, 8191:16384, 8191:16384]

    # Case 1
    with (
        memray.Tracker("serial-chunk-indexing.bin", trace_python_allocators=True),
        zarr.config.set({"async.concurrency": 8, "threading.max_workers": 8}),
    ):
        for i in range(arr.shape[1]):
            _ = arr[0, i : i + 1, 8191:16384, 8191:16384]

    flamegraph("fancy-indexing.bin", "fancy-indexing.html")
    flamegraph("colon-slice-indexing.bin", "colon-slice-indexing.html")
    flamegraph("serial-chunk-indexing.bin", "serial-chunk-indexing.html")

    return [
        FlyteFile("fancy-indexing.html"),
        FlyteFile("colon-slice-indexing.html"),
        FlyteFile("serial-chunk-indexing.html"),
    ]

single-indexing.html
fancy-indexing.html
serial-chunk-indexing.html
colon-slice-indexing.html
