diff --git a/docs/source/notebooks/data/zea_data_example.ipynb b/docs/source/notebooks/data/zea_data_example.ipynb index 977930d25..5bb4723be 100644 --- a/docs/source/notebooks/data/zea_data_example.ipynb +++ b/docs/source/notebooks/data/zea_data_example.ipynb @@ -9,7 +9,7 @@ "\n", "1. Loading data from single file with `zea.File`\n", "2. Loading data from a group of files with `zea.Dataset`\n", - "3. Loading data in batches with dataloading utilities with `zea.backend.tensorflow.make_dataloader`" + "3. Loading data in batches with dataloading utilities with `zea.Dataloader`" ] }, { @@ -89,9 +89,8 @@ "import matplotlib.pyplot as plt\n", "\n", "import zea\n", - "from zea import init_device, load_file\n", - "from zea.visualize import set_mpl_style\n", - "from zea.backend.tensorflow import make_dataloader" + "from zea import init_device, load_file, Dataloader\n", + "from zea.visualize import set_mpl_style" ] }, { @@ -378,9 +377,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Loading data with `make_dataloader`\n", + "## Loading data with `Dataloader`\n", "\n", - "In machine and deep learning workflows, we often want more features like batching, shuffling, and parallel data loading. The `zea.backend.tensorflow.make_dataloader` function provides a convenient way to create a TensorFlow data loader from a zea dataset. This does require a working TensorFlow installation, but does work in combination with any other backend as well. This dataloader is particularly useful for training models. It is important that there is some consistency in the dataset, which is not the case for [PICMUS](https://www.creatis.insa-lyon.fr/Challenge/IEEE_IUS_2016/home). Therefore in this example we will use a small part of the [CAMUS](https://www.creatis.insa-lyon.fr/Challenge/camus/) dataset." + "In machine and deep learning workflows, we often want more features like batching, shuffling, and parallel data loading. 
The `zea.Dataloader` class provides a convenient way to create a high-performance data loader from a zea dataset. It is built on Grain and does not require TensorFlow. This dataloader is particularly useful for training models. Consistency of shape is preferred, which is not the case for [PICMUS](https://www.creatis.insa-lyon.fr/Challenge/IEEE_IUS_2016/home). Therefore in this example we will use a small part of the [CAMUS](https://www.creatis.insa-lyon.fr/Challenge/camus/) dataset." ] }, { @@ -460,7 +459,7 @@ ], "source": [ "dataset_path = \"hf://zeahub/camus-sample/val\"\n", - "dataloader = make_dataloader(\n", + "dataloader = Dataloader(\n", " dataset_path,\n", " key=\"data/image_sc\",\n", " batch_size=4,\n", diff --git a/docs/source/notebooks/metrics.rst b/docs/source/notebooks/metrics.rst index 8e6a40dbf..c95030247 100644 --- a/docs/source/notebooks/metrics.rst +++ b/docs/source/notebooks/metrics.rst @@ -1,5 +1,5 @@ Metrics -======= +======== .. toctree:: :maxdepth: 1 diff --git a/docs/source/notebooks/metrics/lpips_example.ipynb b/docs/source/notebooks/metrics/lpips_example.ipynb index 1e9a3f084..5c46f1b44 100644 --- a/docs/source/notebooks/metrics/lpips_example.ipynb +++ b/docs/source/notebooks/metrics/lpips_example.ipynb @@ -77,8 +77,7 @@ "import numpy as np\n", "from keras import ops\n", "\n", - "from zea import init_device\n", - "from zea.backend.tensorflow.dataloader import make_dataloader\n", + "from zea import init_device, Dataloader\n", "from zea.models.lpips import LPIPS\n", "from zea.visualize import plot_image_grid, set_mpl_style" ] @@ -165,7 +164,7 @@ ], "source": [ "n_imgs = 9\n", - "val_dataset = make_dataloader(\n", + "val_dataset = Dataloader(\n", " \"hf://zeahub/camus-sample/val\",\n", " key=\"data/image\",\n", " batch_size=n_imgs,\n", diff --git a/docs/source/notebooks/metrics/myocardial_quality_example.ipynb b/docs/source/notebooks/metrics/myocardial_quality_example.ipynb index 8185e84c2..1d3867760 100644 --- 
a/docs/source/notebooks/metrics/myocardial_quality_example.ipynb +++ b/docs/source/notebooks/metrics/myocardial_quality_example.ipynb @@ -61,8 +61,7 @@ "from zea.visualize import plot_shape_from_mask\n", "import numpy as np\n", "\n", - "from zea import init_device\n", - "from zea.backend.tensorflow.dataloader import make_dataloader\n", + "from zea import init_device, Dataloader\n", "from zea.visualize import set_mpl_style\n", "from zea.io_lib import matplotlib_figure_to_numpy, save_video\n", "\n", @@ -142,7 +141,7 @@ "# Load a batch and run both models.\n", "n_imgs = 1\n", "INFERENCE_SIZE = 256\n", - "val_dataset = make_dataloader(\n", + "val_dataset = Dataloader(\n", " \"hf://zeahub/camus-sample/val\",\n", " key=\"data/image_sc\",\n", " batch_size=n_imgs,\n", diff --git a/docs/source/notebooks/models/hvae_model_example.ipynb b/docs/source/notebooks/models/hvae_model_example.ipynb index b382be984..3590c55bd 100644 --- a/docs/source/notebooks/models/hvae_model_example.ipynb +++ b/docs/source/notebooks/models/hvae_model_example.ipynb @@ -75,11 +75,10 @@ "\n", "from zea.models.hvae import HierarchicalVAE\n", "\n", - "from zea import init_device\n", + "from zea import init_device, Dataloader\n", "from zea.display import scan_convert_2d\n", "from zea.agent.selection import UniformRandomLines\n", "from zea.visualize import set_mpl_style, plot_image_grid\n", - "from zea.backend.tensorflow.dataloader import make_dataloader\n", "\n", "init_device(verbose=False)\n", "set_mpl_style()" @@ -135,7 +134,7 @@ } ], "source": [ - "val_dataset = make_dataloader(\n", + "val_dataset = Dataloader(\n", " \"hf://zeahub/camus-sample/val\",\n", " key=\"data/image\",\n", " batch_size=batch_size,\n", diff --git a/docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb b/docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb index 3f1c6d4e8..81f691533 100644 --- a/docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb +++ 
b/docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb @@ -67,10 +67,9 @@ "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", "\n", - "from zea import init_device\n", + "from zea import init_device, Dataloader\n", "import matplotlib.pyplot as plt\n", "from keras import ops\n", - "from zea.backend.tensorflow.dataloader import make_dataloader\n", "from zea.visualize import plot_shape_from_mask\n", "from zea.func import translate\n", "from zea.visualize import plot_image_grid, set_mpl_style\n", @@ -116,7 +115,7 @@ "source": [ "n_imgs = 16\n", "INFERENCE_SIZE = 256 # Used for both models\n", - "val_dataset = make_dataloader(\n", + "val_dataset = Dataloader(\n", " \"hf://zeahub/camus-sample/val\",\n", " key=\"data/image_sc\",\n", " batch_size=n_imgs,\n", diff --git a/docs/source/notebooks/models/taesd_autoencoder_example.ipynb b/docs/source/notebooks/models/taesd_autoencoder_example.ipynb index bd67f9f37..a55721faf 100644 --- a/docs/source/notebooks/models/taesd_autoencoder_example.ipynb +++ b/docs/source/notebooks/models/taesd_autoencoder_example.ipynb @@ -80,8 +80,7 @@ "\n", "\n", "import zea\n", - "from zea import init_device\n", - "from zea.backend.tensorflow.dataloader import make_dataloader\n", + "from zea import init_device, Dataloader\n", "from zea.models.taesd import TinyAutoencoder\n", "from zea.visualize import plot_image_grid, set_mpl_style" ] @@ -143,7 +142,7 @@ ], "source": [ "n_imgs = 4\n", - "val_dataset = make_dataloader(\n", + "val_dataset = Dataloader(\n", " \"hf://zeahub/camus-sample/val\",\n", " key=\"data/image\",\n", " batch_size=n_imgs,\n", diff --git a/docs/source/notebooks/models/unet_example.ipynb b/docs/source/notebooks/models/unet_example.ipynb index 3fa11be44..5a9c4b90a 100644 --- a/docs/source/notebooks/models/unet_example.ipynb +++ b/docs/source/notebooks/models/unet_example.ipynb @@ -76,8 +76,7 @@ "import matplotlib.pyplot as plt\n", "from keras import ops\n", "\n", - 
"from zea import init_device, log\n", - "from zea.backend.tensorflow.dataloader import make_dataloader\n", + "from zea import init_device, log, Dataloader\n", "from zea.models.unet import UNet\n", "from zea.models.lpips import LPIPS\n", "from zea.agent.masks import random_uniform_lines\n", @@ -142,7 +141,7 @@ "source": [ "n_imgs = 8\n", "\n", - "val_dataset = make_dataloader(\n", + "val_dataset = Dataloader(\n", " \"hf://zeahub/camus-sample/val\",\n", " key=\"data/image\",\n", " batch_size=n_imgs,\n", diff --git a/poetry.lock b/poetry.lock index bf9ea5199..0bb4f94e0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -225,6 +225,68 @@ files = [ {file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"}, ] +[[package]] +name = "array-record" +version = "0.8.1" +description = "A file format that achieves a new frontier of IO efficiency" +optional = false +python-versions = ">=3.10" +groups = ["main"] +markers = "python_version == \"3.10\" and sys_platform != \"win32\"" +files = [ + {file = "array_record-0.8.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1bf3fcf0a23591667dbd1fe3d32983e8c864781cef5bfd070885ebd3e39e4594"}, + {file = "array_record-0.8.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:39a69d8b8259bce043a8218f2defd46eb44cba0c2f60b7254fd1daaf26e0ea67"}, + {file = "array_record-0.8.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b63f02209bdd3acbb175248d04e0f3f827950eae72f010e6ec14e2732aaf008e"}, + {file = "array_record-0.8.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e5e8e610f59f5a7958dd5672d4b8e505213195594150fdabb7cc0fa725518ed0"}, + {file = "array_record-0.8.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:66516d7a9ef64a9f4c47e1d0d1cfe52b3a507c6835697ec43718ab6115ce8db5"}, + {file = "array_record-0.8.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:090d82858a931e3754e1062f5c62c20ae764f3e5f24e79ea9054393804383a8d"}, + {file = "array_record-0.8.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:959ee895dfce9decd514d7516f27e1b07660624b239f80987a6f430845acbe6f"}, + {file = "array_record-0.8.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b1d0236d4fa93fc7fa721ba7897a49ee38a758fca4aec79f8824ff4dee1056b8"}, + {file = "array_record-0.8.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:005fb6f4f4e833b2de1c4e51c620ad6a30a893b46dc168039d9bb31b71466070"}, + {file = "array_record-0.8.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a753f10fb9b76a895ef4fd8467848221a75a5d006355e0ec781d1dd7ca69eff6"}, + {file = "array_record-0.8.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:86478265c96df1ec5cf2e124aa4a416e033eef9da9f72b05a776cee9871fb00a"}, + {file = "array_record-0.8.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:37bb802e1fdc98edb60a859552c7e7d494e141cf390bfb7350acfd63467daeae"}, +] + +[package.dependencies] +absl-py = "*" +etils = {version = "*", extras = ["epath"]} + +[package.extras] +beam = ["apache-beam[gcp] (>=2.53.0)", "google-cloud-storage (>=2.11.0)", "tensorflow (>=2.20.0)"] +test = ["grain", "jax", "tensorflow (>=2.20.0)"] + +[[package]] +name = "array-record" +version = "0.8.3" +description = "A file format that achieves a new frontier of IO efficiency" +optional = false +python-versions = ">=3.11" +groups = ["main"] +markers = "python_version >= \"3.11\" and sys_platform != \"win32\"" +files = [ + {file = "array_record-0.8.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8a0caee17d2583638391fb9f90c9c12a43f9665576f18ec3604ba49d62bd2dcc"}, + {file = "array_record-0.8.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9f91636366038afea55061cc539601e353d04ebf32608ae8717df4a4edbd31df"}, + {file = 
"array_record-0.8.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f4004fcceb7961a2cc0090028e68a3eeb91d3f5b38c00d60b533f7004d1d3f7b"}, + {file = "array_record-0.8.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dada305fa0dfa3fd6f5f263c43ed37546f815e4f33ce30b066175384dead752e"}, + {file = "array_record-0.8.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:718403dc9a364519a5fc440ed9e2784077e965489d5a44c96970d3101101e1cd"}, + {file = "array_record-0.8.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:458f6658de86c9369b23ffe64dcb31393a919b91ae2f15147ee2beb2010b122a"}, + {file = "array_record-0.8.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:fa94fa053c1afaecc183a1c31463fd89d1b7b148cc526095bfc50ab58967e47c"}, + {file = "array_record-0.8.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adfe6c92918363747539c0caa579f568f75aac079ad7bbd4ebcf7fdb02e5461b"}, + {file = "array_record-0.8.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:13df1aec9b38afe98973bf5fe3cf523f83ca904b0423d85319930877145b8f28"}, + {file = "array_record-0.8.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:14facb713fc55ea612cf853c0abbc602032ae619eb85037af3aa41dc4e27eed6"}, + {file = "array_record-0.8.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:06b17a0a90bc74a80c4ebad0b99f0e039adb01746b673db3d7a5632d8b138ddd"}, + {file = "array_record-0.8.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:89d2cf5f4709be3c6c2b30c0d366005223375633ffce66dcb13d927a0bdc5228"}, +] + +[package.dependencies] +absl-py = "*" +etils = {version = "*", extras = ["epath"]} + +[package.extras] +beam = ["apache-beam[gcp] (>=2.53.0)", "google-cloud-storage (>=2.11.0)", "tensorflow (>=2.20.0)"] +test = ["grain", "jax", "tensorflow (>=2.20.0)"] + [[package]] name = "astor" version = "0.8.1" @@ -262,7 +324,7 @@ description 
= "Classes Without Boilerplate" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"dev\" or extra == \"tests\" or extra == \"docs\"" +markers = "extra == \"dev\" or extra == \"tests\" or extra == \"docs\" or python_version == \"3.10\"" files = [ {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"}, {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"}, @@ -589,10 +651,9 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "cloudpickle" version = "3.1.1" description = "Pickler class to extend the standard pickle.Pickler functionality" -optional = true +optional = false python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"dev\" or extra == \"tests\"" files = [ {file = "cloudpickle-3.1.1-py3-none-any.whl", hash = "sha256:c8c5a44295039331ee9dad40ba100a9c7297b6f988e50e87ccdf3765a668350e"}, {file = "cloudpickle-3.1.1.tar.gz", hash = "sha256:b216fa8ae4019d5482a8ac3c95d8f6346115d8835911fd4aefd1a445e4242c64"}, @@ -901,6 +962,43 @@ files = [ {file = "distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403"}, ] +[[package]] +name = "dm-tree" +version = "0.1.9" +description = "Tree is a library for working with nested data structures." 
+optional = false +python-versions = ">=3.10" +groups = ["main"] +markers = "python_version == \"3.10\"" +files = [ + {file = "dm_tree-0.1.9-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5d5b28ee2e461b6af65330c143806a6d0945dcabbb8d22d2ba863e6dabd9254e"}, + {file = "dm_tree-0.1.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54d5616015412311df154908069fcf2c2d8786f6088a2ae3554d186cdf2b1e15"}, + {file = "dm_tree-0.1.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831699d2c60a1b38776a193b7143ae0acad0a687d87654e6d3342584166816bc"}, + {file = "dm_tree-0.1.9-cp310-cp310-win_amd64.whl", hash = "sha256:1ae3cbff592bb3f2e197f5a8030de4a94e292e6cdd85adeea0b971d07a1b85f2"}, + {file = "dm_tree-0.1.9-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7d7d784afaeb4b67d87d858261aaf02503939ddc1f09c4cca70728f9892ab004"}, + {file = "dm_tree-0.1.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e660d1779ddcbd1348410d08f67db4870d413a3ec4ba8b4b045bd5ce4bd8f35c"}, + {file = "dm_tree-0.1.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:294dc1cecf87552a45cdd5ddb215e7f5295a5a47c46f1f0a0463c3dd02a527d7"}, + {file = "dm_tree-0.1.9-cp311-cp311-win_amd64.whl", hash = "sha256:12f4cc6cd52a39aa38ff31577b6d79b6136a9a89273a876bf62335c9f65c27bf"}, + {file = "dm_tree-0.1.9-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a8d20eeab7fde77a3ed71f07716021eb0edfb4812a128eb381d108af3a310257"}, + {file = "dm_tree-0.1.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80c43417814b1181d3367b335460bfdd30b79ee187a64220e11f6ddd093a4b15"}, + {file = "dm_tree-0.1.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2334cfe9d2ed4293f9f1c7aefba0657deaab9ea74b5fadd966f6d01d9b6b42d9"}, + {file = "dm_tree-0.1.9-cp312-cp312-win_amd64.whl", hash = "sha256:9020a5ce256fcc83aa4bc190cc96dd66e87685db0a6e501b0c06aa492c2e38fc"}, + {file 
= "dm_tree-0.1.9-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:cfa33c2e028155810ad1b4e11928707bf47489516763a86e79cab2954d23bf68"}, + {file = "dm_tree-0.1.9-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d05622d074353cf434049206e53c12147903a048c4bd7d77f2800d427413ad78"}, + {file = "dm_tree-0.1.9-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68b0efad76703dd4648586c75618a48cdd671b68c3266fe980e323c15423607"}, + {file = "dm_tree-0.1.9-cp313-cp313-win_amd64.whl", hash = "sha256:e97c34fcb44941c36b7ee81dcdbceba0fbe728bddcc77e5837ab2eb665bcbff8"}, + {file = "dm_tree-0.1.9-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b06e7a5da1c31a82521a60060573527e8d24b9920fdd20b2ec86f08412737598"}, + {file = "dm_tree-0.1.9-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6893fcdc5cf1a4f459cfc383526d35d42e7c671ae565d7e429a2f2cb2cb93e89"}, + {file = "dm_tree-0.1.9-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1f5d1e96b3a7de22b25b13a5eb30f41f8cf9c02dd4479a24920de99e780903c"}, + {file = "dm_tree-0.1.9.tar.gz", hash = "sha256:a4c7db3d3935a5a2d5e4b383fc26c6b0cd6f78c6d4605d3e7b518800ecd5342b"}, +] + +[package.dependencies] +absl-py = ">=0.6.1" +attrs = ">=18.2.0" +numpy = {version = ">=1.21.2", markers = "python_version >= \"3.10\""} +wrapt = ">=1.11.2" + [[package]] name = "docutils" version = "0.21.2" @@ -927,6 +1025,44 @@ files = [ {file = "entrypoints-0.4.tar.gz", hash = "sha256:b706eddaa9218a19ebcd67b56818f05bb27589b1ca9e8d797b74affad4ccacd4"}, ] +[[package]] +name = "etils" +version = "1.13.0" +description = "Collection of common python utils" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "etils-1.13.0-py3-none-any.whl", hash = "sha256:d9cd4f40fbe77ad6613b7348a18132cc511237b6c076dbb89105c0b520a4c6bb"}, + {file = "etils-1.13.0.tar.gz", hash = 
"sha256:a5b60c71f95bcd2d43d4e9fb3dc3879120c1f60472bb5ce19f7a860b1d44f607"}, +] + +[package.dependencies] +fsspec = {version = "*", optional = true, markers = "extra == \"epath\""} +importlib_resources = {version = "*", optional = true, markers = "extra == \"epath\""} +typing_extensions = {version = "*", optional = true, markers = "extra == \"epath\" or extra == \"epy\""} +zipp = {version = "*", optional = true, markers = "extra == \"epath\""} + +[package.extras] +all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath-gcs]", "etils[epath-s3]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"] +array-types = ["etils[enp]"] +dev = ["chex", "fiddle", "optree", "pydantic", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "tensorflow_datasets", "torch"] +docs = ["etils[all,dev]", "sphinx-apitree[ext]"] +eapp = ["absl-py", "etils[epy]", "simple_parsing"] +ecolab = ["etils[enp]", "etils[epy]", "etils[etree]", "jupyter", "mediapy", "numpy", "packaging", "protobuf"] +edc = ["etils[epy]"] +enp = ["einops", "etils[epy]", "numpy"] +epath = ["etils[epy]", "fsspec", "importlib_resources", "typing_extensions", "zipp"] +epath-gcs = ["etils[epath]", "gcsfs"] +epath-s3 = ["etils[epath]", "s3fs"] +epy = ["typing_extensions"] +etqdm = ["absl-py", "etils[epy]", "tqdm"] +etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"] +etree-dm = ["dm-tree", "etils[etree]"] +etree-jax = ["etils[etree]", "jax[cpu]"] +etree-tf = ["etils[etree]", "tensorflow"] +lazy-imports = ["etils[ecolab]"] + [[package]] name = "exceptiongroup" version = "1.3.0" @@ -1356,6 +1492,85 @@ files = [ [package.dependencies] six = "*" +[[package]] +name = "grain" +version = "0.2.13" +description = "Grain: A library for loading and transforming data for ML training." 
+optional = false +python-versions = ">=3.10" +groups = ["main"] +markers = "python_version == \"3.10\"" +files = [ + {file = "grain-0.2.13-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:08eccde723cf5f0548beb9ebc9401b641c853b4a9e1ef184005635d3c561c106"}, + {file = "grain-0.2.13-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:aedca377a5c85a4776bbd7ba6e742805933193ec84d79232764a86025e2eb9d8"}, + {file = "grain-0.2.13-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d92434255aa385d5548d85d606cf3a79cb587f8deb9281c9fc5b144694b32b92"}, + {file = "grain-0.2.13-cp310-cp310-win_amd64.whl", hash = "sha256:5356bd0491f736ce84a9e70e8475aa582d4b4053072777d849f9fcb625f36299"}, + {file = "grain-0.2.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2f02a449fbbbfa006690f32a5305b73c048061b9744b86d391d713aed3b52952"}, + {file = "grain-0.2.13-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8227a0f81d200c71cb006dbcb268c490b560058283f9f761e9f5090c98f90d56"}, + {file = "grain-0.2.13-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a059266b51682cc96fd138a566ac8d5474adda1e8954bccbe6b3558c85c8fb65"}, + {file = "grain-0.2.13-cp311-cp311-win_amd64.whl", hash = "sha256:11425b5ae59340d62b784b28a47f98eff1a360435c9cde58198b4f56855f97c5"}, + {file = "grain-0.2.13-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9a8dc2bbf53920af14b833b0125d3c64247b00947c44d37ae8dcfc96e40a8469"}, + {file = "grain-0.2.13-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:97f231969f9f305dd37edbb11a58d1f51b335f5cfefbb9418488a52affab428f"}, + {file = "grain-0.2.13-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21de2a003a2d7971bb2d56f667d66eb9a74f4b1c533157000d8dc42cab1af80c"}, + {file = "grain-0.2.13-cp312-cp312-win_amd64.whl", hash = "sha256:2e994930573fec940a1836f2d0ae679dbf0a3f0c0d7696068fa2c354db12444e"}, + {file = 
"grain-0.2.13-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:00a3befab959d46e559b31482ddbc1044b4afddd2e95ad36b3e11924133333f8"}, + {file = "grain-0.2.13-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:94e92790fc2f09d54ac9dedbaf16761cfd505bdf022d936dfca46ce3ff7558c8"}, + {file = "grain-0.2.13-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e877ca7ed8ae32508e1264385d659969e3492e1138949dfaf6410572203300ae"}, + {file = "grain-0.2.13-cp313-cp313-win_amd64.whl", hash = "sha256:496017e4485ea4643364d08e4ba96c0a3f433f09a10c89e62c0555c5afc7211e"}, +] + +[package.dependencies] +absl-py = "*" +array-record = {version = ">=0.8.0a1", markers = "sys_platform != \"win32\""} +cloudpickle = "*" +dm-tree = "*" +etils = {version = "*", extras = ["epath", "epy"]} +more-itertools = ">=9.1.0" +numpy = "*" +protobuf = ">=5.28.3" + +[package.extras] +parquet = ["pyarrow"] +testing = ["attrs", "dill", "jax", "jaxlib", "jaxtyping", "pyarrow", "pytest", "tensorflow-datasets"] + +[[package]] +name = "grain" +version = "0.2.15" +description = "Grain: A library for loading and transforming data for ML training." 
+optional = false +python-versions = ">=3.11" +groups = ["main"] +markers = "python_version >= \"3.11\"" +files = [ + {file = "grain-0.2.15-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:63b90d0719a97c25a8b30057e224436dfaf3ba975e5baadaf8b3417235a09631"}, + {file = "grain-0.2.15-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:60f5ded548e58251af4a7429779c472cdb5ab716a448c32e500e6eb79957c5e9"}, + {file = "grain-0.2.15-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fcac68693f539f5421560250fe906c62b00daeb58a28a56ac7b8d9dad2ba7483"}, + {file = "grain-0.2.15-cp311-cp311-win_amd64.whl", hash = "sha256:7e8d7114f079cb3cfa9953e699cb44a44b179e9fb2ca4f61e8b7ac8ca90c33ce"}, + {file = "grain-0.2.15-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:149b195ebd2a2e06e00019af5e83d7d6fef18b6149471ba0baaf6bdfac1a0558"}, + {file = "grain-0.2.15-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9bbf2c57a92fae5d883e04bda35d373202f026cbd9952b9bdd831bda2c60ddd7"}, + {file = "grain-0.2.15-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fecc1553e539b946ed7f3508261f98309d07f899c06af1954186cf7d6a4613b2"}, + {file = "grain-0.2.15-cp312-cp312-win_amd64.whl", hash = "sha256:cc58394bdddd5db1b1e1e2878899a1e78014f60046ead82415011eca36475287"}, + {file = "grain-0.2.15-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df0c4af1442f71effc80a5b44bb33fe5aabb7b3d82d5a3ac035ae40431935c77"}, + {file = "grain-0.2.15-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4766dc1355448cf7b3bdd3cc4f4639bc3433c66c5502954380b9a8f4c2b3fc94"}, + {file = "grain-0.2.15-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ae2635e3bd18b77a7add7152ee91b22aba553fa44a066aefa009d1b923ade4f2"}, + {file = "grain-0.2.15-cp313-cp313-win_amd64.whl", hash = "sha256:9d3cc6edfa4d0de341de0d03ba4a8612d94639bed97fe63af3951db8305c96c0"}, + {file = 
"grain-0.2.15-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:36e1413c275741712918e528f3f65ca254ce51d60160a2b9558c020147f38e98"}, + {file = "grain-0.2.15-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fedbb410486890c6e410c8d78934097b2755aa2363fbd6137b920ea5be104530"}, + {file = "grain-0.2.15-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e1a5e20ffc29391c48b5d8f818742f4ecce957fe59ccff971c139721d31a0bf4"}, + {file = "grain-0.2.15-cp314-cp314-win_amd64.whl", hash = "sha256:d57493fa4361316755cb0564f87e62d1da1dedae4eed53acf179ec029668dee5"}, +] + +[package.dependencies] +absl-py = "*" +array-record = {version = ">=0.8.1", markers = "sys_platform != \"win32\""} +cloudpickle = "*" +etils = {version = "*", extras = ["epath", "epy"]} +numpy = "*" +protobuf = ">=5.28.3" + +[package.extras] +parquet = ["pyarrow"] + [[package]] name = "grpcio" version = "1.73.0" @@ -1647,6 +1862,26 @@ files = [ {file = "imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a"}, ] +[[package]] +name = "importlib-resources" +version = "6.5.2" +description = "Read resources from Python packages" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec"}, + {file = "importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"] +type = ["pytest-mypy"] + [[package]] name = "iniconfig" version = "2.1.0" @@ -2656,6 
+2891,19 @@ numpy = [ [package.extras] dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] +[[package]] +name = "more-itertools" +version = "10.8.0" +description = "More routines for operating on iterables, beyond itertools" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version == \"3.10\"" +files = [ + {file = "more_itertools-10.8.0-py3-none-any.whl", hash = "sha256:52d4362373dcf7c52546bc4af9a86ee7c4579df9a8dc268be0a2f949d376cc9b"}, + {file = "more_itertools-10.8.0.tar.gz", hash = "sha256:f638ddf8a1a0d134181275fb5d58b086ead7c6a72429ad725c67503f13ba30bd"}, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -6449,7 +6697,7 @@ description = "Module for decorators, wrappers and monkey patching." optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"backends\"" +markers = "extra == \"backends\" or python_version == \"3.10\"" files = [ {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d57c572081fed831ad2d26fd430d565b76aa277ed1d30ff4d40670b1c0dd984"}, {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5e251054542ae57ac7f3fba5d10bfff615b6c2fb09abeb37d2f1463f841ae22"}, @@ -6652,6 +6900,26 @@ idna = ">=2.0" multidict = ">=4.0" propcache = ">=0.2.1" +[[package]] +name = "zipp" +version = "3.23.0" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e"}, + {file = "zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", 
"sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +type = ["pytest-mypy"] + [extras] backends = ["jax", "tensorflow", "torch"] dev = ["IPython", "PyStemmer", "cloudpickle", "furo", "ipykernel", "ipywidgets", "myst-parser", "nbsphinx", "onnxruntime", "opencv-python-headless", "papermill", "pre-commit", "pytest", "pytest-cov", "ruff", "setuptools", "simpleitk", "sphinx", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx-copybutton", "sphinx-reredirects", "sphinx_design", "sphinxcontrib-autoprogram", "sphinxcontrib-bibtex"] @@ -6665,4 +6933,4 @@ tests = ["cloudpickle", "ipykernel", "ipywidgets", "papermill", "pre-commit", "p [metadata] lock-version = "2.1" python-versions = ">=3.10" -content-hash = "1becbaba428eb2be827a5f718b4499ce043bfa6f34bc8b42e57aab12e3de88f3" +content-hash = "a9a95f36adeafeff2728bcea96c45a7fed44fd1fc0f222f20c71dd129a50f69c" diff --git a/pyproject.toml b/pyproject.toml index 5922d1692..1dabe862a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "wandb >=0.18", "wget >=3.2", "imageio[ffmpeg] >=2.0", + "grain >= 0.2", # can we make these optional or remove? 
"scikit-image >=0.23", "scikit-learn >=1.4", diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 6415cddce..69f6daf80 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -1,8 +1,7 @@ -"""Test Tensorflow H5 dataloader functions""" +"""Test H5 dataloader functions""" import hashlib import pickle -from copy import deepcopy import h5py import keras @@ -10,14 +9,11 @@ import pytest from keras import ops -from zea import log -from zea.backend.tensorflow.dataloader import make_dataloader from zea.data.augmentations import RandomCircleInclusion -from zea.data.dataloader import MAX_RETRY_ATTEMPTS, H5Generator +from zea.data.dataloader import Dataloader, H5DataSource from zea.data.datasets import Dataset from zea.data.file import File from zea.data.layers import Resizer -from zea.data.utils import json_loads from zea.tools.hf import HFPath from .. import DEFAULT_TEST_SEED @@ -80,15 +76,14 @@ def camus_file(): return CAMUS_FILE -def _get_h5_generator(file_path, key, n_frames, insert_frame_axis, seed=None, validate=True): +def _get_h5_data_source(file_path, key, n_frames, insert_frame_axis, validate=True): file_paths = [file_path] - # Create a H5Generator instance - generator = H5Generator( + + generator = H5DataSource( file_paths=file_paths, key=key, n_frames=n_frames, insert_frame_axis=insert_frame_axis, - seed=seed, validate=validate, ) return generator @@ -108,15 +103,17 @@ def _get_h5_generator(file_path, key, n_frames, insert_frame_axis, seed=None, va ("camus_file", "data/image_sc", 15, False), ], ) -def test_h5_generator(file_path, key, n_frames, insert_frame_axis, request): - """Test the H5Generator class""" +def test_h5_data_source(file_path, key, n_frames, insert_frame_axis, request): + """Test the H5DataSource class""" validate = file_path != "dummy_hdf5" file_path = request.getfixturevalue(file_path) - generator = _get_h5_generator(file_path, key, n_frames, insert_frame_axis, validate=validate) + 
data_source = _get_h5_data_source( + file_path, key, n_frames, insert_frame_axis, validate=validate + ) - batch_shape = next(generator()).shape + batch_shape = data_source[0].shape if insert_frame_axis: assert batch_shape[-1] == n_frames, ( f"Something went wrong as the last dimension of the batch shape {batch_shape[-1]}" @@ -129,19 +126,6 @@ def test_h5_generator(file_path, key, n_frames, insert_frame_axis, request): ) -def test_h5_generator_shuffle(dummy_hdf5): - """Test the H5Generator class""" - - generator = _get_h5_generator( - dummy_hdf5, "data", 10, False, seed=DEFAULT_TEST_SEED, validate=False - ) - - # Test shuffle - shuffled_items = deepcopy(generator.shuffled_items) - generator._shuffle() - assert shuffled_items != generator.shuffled_items, "The generator indices were not shuffled" - - @pytest.mark.parametrize( "directory, key, n_frames, insert_frame_axis, num_files, total_samples", [ @@ -185,7 +169,7 @@ def test_dataloader( [length // n_frames if not insert_frame_axis else length for length in file_lengths] ) - dataset = make_dataloader( + dataset = Dataloader( directory, batch_size=1, key=key, @@ -216,15 +200,19 @@ def test_dataloader( f" is not equal to the expected length {expected_len_dataset}" ) - # Test shuffling - shuffle_key = {} - for i in range(2): - shuffle_key[i] = "" + # Test shuffling — with very few samples different seeds can produce the + # same permutation, iterate several times and require that at least + # one pair differs. 
+ n_shuffle_iters = 5 + shuffle_keys = [] + for _ in range(n_shuffle_iters): + h = "" for batch in iter(dataset): key = hashlib.md5(pickle.dumps(batch)).hexdigest() - shuffle_key[i] += key + h += key + shuffle_keys.append(h) - assert shuffle_key[0] != shuffle_key[1], "The dataset was not shuffled" + assert len(set(shuffle_keys)) > 1, "The dataset was not shuffled" @pytest.mark.parametrize( @@ -250,7 +238,8 @@ def test_h5_dataset_return_filename( validate = directory != "dummy_hdf5" directory = request.getfixturevalue(directory) - dataset = make_dataloader( + N_AXIS = 3 # n_frames, height, width + dataset = Dataloader( directory, key=key, image_size=image_size, @@ -270,18 +259,31 @@ def test_h5_dataset_return_filename( _, file_dict = batch - assert len(file_dict) == batch_size, ( - "The file_dict should contain the same number of elements as the batch size" - ) + # Check keys + keys = ["filename", "fullpath", "indices"] + for key in keys: + assert key in file_dict, f"The file_dict should contain the key '{key}'" - file_dict = file_dict[0] # get the first file_dict of the batch - file_dict = json_loads(file_dict.numpy()) + # Check batch size and types + keys = ["filename", "fullpath"] + for key in keys: + assert len(file_dict[key]) == batch_size, ( + f"The file_dict['{key}'] should contain the same number of elements as the batch size" + ) + for path in file_dict[key]: + assert isinstance(path, str), f"Each path in file_dict['{key}'] should be a string" - filename = file_dict["filename"] - assert isinstance(filename, str), "The filename should be a string" - fullpath = file_dict["fullpath"] - assert isinstance(fullpath, str), "The fullpath should be a string" - assert "indices" in file_dict, "The file_dict should contain indices" + # indices nests one deeper, because it has one element per axis (n_frames, height, width) + indices = file_dict["indices"] + assert len(indices) == N_AXIS, ( + f"The file_dict['indices'] should contain {N_AXIS} elements in this test" + ) + 
+ for idx in indices: + assert len(idx) == batch_size, ( + "Each axis in file_dict['indices'] should contain the same number of elements " + "as the batch size" + ) @pytest.mark.parametrize( @@ -309,7 +311,7 @@ def test_h5_dataset_resize_types(directory, key, image_size, resize_type, batch_ validate = directory != "dummy_hdf5" directory = request.getfixturevalue(directory) - dataset = make_dataloader( + dataset = Dataloader( directory, key=key, image_size=image_size, @@ -405,7 +407,7 @@ def test_ndim_hdf5_dataset( ): """Test the dataloader with an n-dimensional HDF5 dataset.""" - dataset = make_dataloader( + dataset = Dataloader( ndim_hdf5_dataset_path, key=key, image_size=image_size, @@ -427,61 +429,6 @@ def test_ndim_hdf5_dataset( next(iter(dataset)) -@pytest.mark.parametrize( - "mock_error_count, expected_retries, should_succeed", - [ - (1, 1, True), # One error, should succeed on retry - ( - MAX_RETRY_ATTEMPTS - 1, - MAX_RETRY_ATTEMPTS - 1, - True, - ), # Two errors, should succeed on third try - ( - MAX_RETRY_ATTEMPTS + 1, - MAX_RETRY_ATTEMPTS, - False, - ), # Too many errors, should fail after max retries - ], -) -def test_h5_file_retry_count( - mock_error_count, expected_retries, should_succeed, dummy_hdf5, monkeypatch -): - """Test that the H5Generator correctly counts retries when files are temporarily unavailable.""" - - generator = _get_h5_generator(dummy_hdf5, "data", 1, True, validate=False) - - # Store the original load method - original_load_data = File.load_data - error_count = [0] # Use list to allow modification in closure - - # Create a mock load function that fails a specified number of times - def mock_load_data(self, dtype, indices): - if error_count[0] < mock_error_count: - error_count[0] += 1 - log.debug(f"Simulating I/O error in File.load_data. 
Error count: {error_count[0]}") - raise OSError(f"Simulated file access error (attempt {error_count[0]})") - # After specified failures, call the original method - return original_load_data(self, dtype, indices) - - # Apply the monkeypatch to the zea.file.File class method - monkeypatch.setattr(File, "load_data", mock_load_data) - - if should_succeed: - # Should succeed after retries - batch = next(iter(generator)) - batch = ops.convert_to_numpy(batch) - assert isinstance(batch, np.ndarray), "Failed to get valid data after retries" - else: - # Should fail after max retries - with pytest.raises(ValueError) as exc_info: - next(iter(generator)) - assert "Failed to complete operation" in str(exc_info.value) - - assert generator.retry_count == expected_retries, ( - f"Expected {expected_retries} retries but got {generator.retry_count}" - ) - - @pytest.mark.usefixtures("dummy_hdf5") def test_random_circle_inclusion_augmentation(dummy_hdf5): """Test RandomCircleInclusion augmentation with dataloader.""" @@ -500,7 +447,7 @@ def test_random_circle_inclusion_augmentation(dummy_hdf5): ] ) - dataset = make_dataloader( + dataset = Dataloader( dummy_hdf5, batch_size=4, key="data", @@ -514,7 +461,7 @@ def test_random_circle_inclusion_augmentation(dummy_hdf5): ) images = next(iter(dataset)) - images_np = np.array(images) + images_np = ops.convert_to_numpy(images) # Output shape should match input shape assert images_np.shape == ( @@ -535,7 +482,7 @@ def test_resize_with_different_shapes(multi_shape_dataset): """Test the dataloader class with different image shapes in a batch.""" # Create a dataloader instance with different image shapes - dataset = make_dataloader( + dataset = Dataloader( multi_shape_dataset, key="data", image_size=(16, 16), @@ -549,10 +496,154 @@ def test_resize_with_different_shapes(multi_shape_dataset): # Get the first batch images = next(iter(dataset)) - images_np = np.array(images) + images_np = ops.convert_to_numpy(images) # Output shape should match input 
shape assert images_np.shape[-3:-1] == ( 16, 16, ), f"Output shape {images_np.shape} does not match expected (16, 16)" + + +def test_skipped_files_warning(tmp_path): + """Test warning when files have too few frames for n_frames * frame_index_stride.""" + rng = np.random.default_rng(DEFAULT_TEST_SEED) + # Create file with only 1 frame — requesting n_frames=5 should skip it + with h5py.File(tmp_path / "small_0.hdf5", "w") as f: + f.create_dataset("data", data=rng.standard_normal((1, 28, 28))) + + source = H5DataSource( + file_paths=tmp_path, + key="data", + n_frames=5, + frame_index_stride=1, + validate=False, + ) + assert len(source) == 0 + + +def test_limit_n_samples(dummy_hdf5): + """Test H5DataSource with limit_n_samples caps samples.""" + source = H5DataSource( + file_paths=dummy_hdf5, + key="data", + n_frames=1, + limit_n_samples=5, + validate=False, + ) + assert len(source) == 5 + + +def test_cache_hit_and_store(dummy_hdf5): + """Test caching: first access stores in cache, second access hits cache.""" + source = H5DataSource( + file_paths=dummy_hdf5, + key="data", + n_frames=1, + cache=True, + validate=False, + ) + # First access stores in cache + result1 = source[0] + assert 0 in source._data_cache + + # Second access hits cache + result2 = source[0] + np.testing.assert_array_equal(result1, result2) + + +def test_normalization_without_image_range_raises(dummy_hdf5): + """Test that setting normalization_range without image_range raises.""" + with pytest.raises(AssertionError, match="image_range must be set"): + Dataloader( + dummy_hdf5, + key="data", + normalization_range=(0, 1), + image_range=None, + validate=False, + ) + + +def test_num_shards_without_shard_index_raises(dummy_hdf5): + """Test that num_shards > 1 without shard_index raises.""" + with pytest.raises(AssertionError, match="shard_index must be specified"): + Dataloader( + dummy_hdf5, + key="data", + num_shards=2, + validate=False, + ) + + +def test_auto_seed_generation(dummy_hdf5): + """Test that 
seed is auto-generated when shuffle=True and seed=None.""" + loader = Dataloader( + dummy_hdf5, + key="data", + shuffle=True, + seed=None, + validate=False, + ) + assert loader.seed is not None + + +def test_dataset_property(dummy_hdf5): + """Test the .dataset property returns the underlying MapDataset.""" + loader = Dataloader( + dummy_hdf5, + key="data", + shuffle=False, + validate=False, + ) + assert loader.dataset is not None + + +def test_dataloader_repr(dummy_hdf5): + """Test Dataloader __repr__ includes key information.""" + loader = Dataloader( + dummy_hdf5, + key="data", + shuffle=False, + validate=False, + batch_size=4, + ) + repr_str = repr(loader) + assert "`_. It provides a convenient way to load and preprocess data for machine learning workflows. """ diff --git a/zea/backend/tensorflow/__init__.py b/zea/backend/tensorflow/__init__.py index c4ad9735e..8b48b172e 100644 --- a/zea/backend/tensorflow/__init__.py +++ b/zea/backend/tensorflow/__init__.py @@ -6,12 +6,8 @@ import sys from pathlib import PosixPath -import numpy as np - # Convert PosixPath objects to strings in sys.path # this is necessary due to weird TF bug when importing sys.path = [str(p) if isinstance(p, PosixPath) else p for p in sys.path] import tensorflow as tf # noqa: E402 - -from .dataloader import make_dataloader # noqa: E402 diff --git a/zea/backend/tensorflow/dataloader.py b/zea/backend/tensorflow/dataloader.py deleted file mode 100644 index 713940096..000000000 --- a/zea/backend/tensorflow/dataloader.py +++ /dev/null @@ -1,369 +0,0 @@ -"""HDF5 Tensorflow dataloader. - -Convenient way of loading data from hdf5 files in a ML pipeline. 
-""" - -from functools import partial -from typing import List - -import keras -import tensorflow as tf -from keras.src.trainers.data_adapters import TFDatasetAdapter - -from zea.data.dataloader import H5Generator -from zea.data.layers import Resizer -from zea.func.tensor import translate -from zea.internal.utils import find_methods_with_return_type - -METHODS_THAT_RETURN_DATASET = find_methods_with_return_type(tf.data.Dataset, "DatasetV2") - - -class TFDatasetToKeras(TFDatasetAdapter): - """Tensorflow Dataset to Keras Dataset. - - This class wraps a tf.data.Dataset object and allows it to be used with Keras backends. - """ - - def __init__(self, dataset): - super().__init__(dataset) - - def __iter__(self): - backend = keras.backend.backend() - if backend == "tensorflow": - return iter(self.get_tf_dataset()) - elif backend == "jax": - return self.get_jax_iterator() - elif backend == "torch": - return iter(self.get_torch_dataloader()) - elif backend == "numpy": - return self.get_numpy_iterator() - else: - raise ValueError( - f"Unsupported backend: {backend}. " - "Please use one of the following: 'tensorflow', 'jax', 'torch', 'numpy'." - ) - - def __len__(self): - return self.num_batches - - def __getattr__(self, name): - # Delegate all calls to self._dataset, and wraps the result in TFDatasetToKeras - if name in METHODS_THAT_RETURN_DATASET: - - def method(*args, **kwargs): - result = getattr(self._dataset, name)(*args, **kwargs) - return TFDatasetToKeras(result) - - return method - else: - return getattr(self._dataset, name) - - -class H5GeneratorTF(H5Generator): - """Adds a tensorflow dtype property and output_signature to the H5Generator class.""" - - @property - def tensorflow_dtype(self): - """ - Extracts one image from the dataset to get the dtype. Converts it to a tensorflow dtype. 
- """ - out = next(self.iterator()) - if self.return_filename: - out = out[0] - dtype = out.dtype - if "float" in str(dtype): - dtype = tf.float32 - elif "complex" in str(dtype): - dtype = tf.complex64 - elif "uint8" in str(dtype): - dtype = tf.uint8 - else: - raise ValueError(f"Unsupported dtype: {dtype}") - return dtype - - @property - def output_signature(self): - """ - Get the output signature of the generator as a tensorflow `TensorSpec`. - This is useful for creating a `tf.data.Dataset` from the generator. - """ - output_signature = tf.TensorSpec(shape=self.shape, dtype=self.tensorflow_dtype) - if self.return_filename: - output_signature = ( - output_signature, - tf.TensorSpec(shape=(), dtype=tf.string), - ) - return output_signature - - -def _assert_image_range(images, image_range): - # Check if there are outliers in the image range - minval = tf.reduce_min(images) - maxval = tf.reduce_max(images) - _msg = f"Image range {image_range} is not in the range of the data {minval} - {maxval}" - tf.debugging.assert_greater_equal( - minval, - tf.cast(image_range[0], minval.dtype), - message=_msg, - ) - tf.debugging.assert_less_equal( - maxval, - tf.cast(image_range[1], maxval.dtype), - message=_msg, - ) - return images - - -def make_dataloader( - file_paths: List[str], - batch_size: int, - key: str = "data/image", - n_frames: int = 1, - shuffle: bool = True, - return_filename: bool = False, - limit_n_samples: int | None = None, - limit_n_frames: int | None = None, - seed: int | None = None, - drop_remainder: bool = False, - resize_type: str | None = None, - resize_axes: tuple | None = None, - resize_kwargs: dict | None = None, - image_size: tuple | None = None, - image_range: tuple | None = None, - normalization_range: tuple | None = None, - dataset_repetitions: int | None = None, - cache: bool = False, - additional_axes_iter: tuple | None = None, - sort_files: bool = True, - overlapping_blocks: bool = False, - augmentation: callable = None, - assert_image_range: 
bool = True, - clip_image_range: bool = False, - initial_frame_axis: int = 0, - insert_frame_axis: bool = True, - frame_index_stride: int = 1, - frame_axis: int = -1, - validate: bool = True, - prefetch: bool = True, - shard_index: int | None = None, - num_shards: int = 1, - wrap_in_keras: bool = True, - **kwargs, -) -> tf.data.Dataset: - """Creates a ``tf.data.Dataset`` from .hdf5 files in the specified directory or directories. - - Mimics the native TF function ``tf.keras.utils.image_dataset_from_directory`` - but for .hdf5 files. - - Does the following in order to load a dataset: - - - Find all .hdf5 files in the director(ies) - - Load the data from each file using the specified key - - Apply the following transformations in order (if specified): - - - limit_n_samples - - cache - - shuffle - - shard - - add channel dim - - assert_image_range - - clip_image_range - - resize - - repeat - - batch - - normalize - - augmentation - - prefetch - - tf -> keras tensor - - Args: - file_paths (str or list): Path(s) to the folder(s) or h5 file(s) to load. - batch_size (int): Batch the dataset. - key (str): The key to access the HDF5 dataset. - n_frames (int, optional): Number of frames to load from each hdf5 file. - Defaults to 1. These frames are stacked along the last axis (channel). - shuffle (bool, optional): Shuffle dataset. - return_filename (bool, optional): Return file name with image. Defaults to False. - limit_n_samples (int, optional): Take only a subset of samples. - Useful for debugging. Defaults to None. - limit_n_frames (int, optional): Limit the number of frames to load from each file. - This means n_frames per data file will be used. These will be the first frames in - the file. Defaults to None. - seed (int, optional): Random seed of shuffle. - drop_remainder (bool, optional): Whether the last batch should be dropped. - resize_type (str, optional): Resize type. Defaults to 'center_crop'. - Can be 'center_crop', 'random_crop' or 'resize'. 
- resize_axes (tuple, optional): Axes to resize along. Should be of length 2 - (height, width) as resizing function only supports 2D resizing / cropping. - Should only be set when your data is more than (h, w, c). Defaults to None. - Note that it considers the axes after inserting the frame axis. - resize_kwargs (dict, optional): Kwargs for the resize function. - image_size (tuple, optional): Resize images to image_size. Should - be of length two (height, width). Defaults to None. - image_range (tuple, optional): Image range. Defaults to (0, 255). - Will always translate from specified image range to normalization range. - If image_range is set to None, no normalization will be done. Note that it does not - clip to the image range, so values outside the image range will be outside the - normalization range! - normalization_range (tuple, optional): Normalization range. Defaults to (0, 1). - See image_range for more info! - dataset_repetitions (int, optional): Repeat dataset. Note that this happens - after sharding, so the shard will be repeated. Defaults to None. - cache (bool, optional): Cache dataset to RAM. - additional_axes_iter (tuple, optional): Additional axes to iterate over - in the dataset. Defaults to None, in that case we only iterate over - the first axis (we assume those contain the frames). - sort_files (bool, optional): Sort files by number. Defaults to True. - overlapping_blocks (bool, optional): If True, blocks overlap by n_frames - 1. - Defaults to False. Has no effect if n_frames = 1. - augmentation (keras.Sequential, optional): Keras augmentation layer. - assert_image_range (bool, optional): Assert that the image range is - within the specified image range. Defaults to True. - clip_image_range (bool, optional): Clip the image range to the specified - image range. Defaults to False. - initial_frame_axis (int, optional): Axis where in the files the frames are stored. - Defaults to 0. 
- insert_frame_axis (bool, optional): If True, new dimension to stack - frames along will be created. Defaults to True. In that case - frames will be stacked along existing dimension (frame_axis). - frame_index_stride (int, optional): Interval between frames to load. - Defaults to 1. If n_frames > 1, a lower frame rate can be simulated. - frame_axis (int, optional): Dimension to stack frames along. - Defaults to -1. If insert_frame_axis is True, this will be the - new dimension to stack frames along. - validate (bool, optional): Validate if the dataset adheres to the zea format. - Defaults to True. - prefetch (bool, optional): Prefetch the dataset. Defaults to True. - shard_index (int, optional): Index which part of the dataset should be selected. - Can only be used if num_shards is specified. Defaults to None. - See for info: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard - num_shards (int, optional): This is used to divide the dataset into ``num_shards`` parts. - Sharding happens before all other operations. Defaults to 1. - See for info: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard - wrap_in_keras (bool, optional): Wrap dataset in TFDatasetToKeras. Defaults to True. - If True, will convert the dataset that returns backend tensors. - - Returns: - tf.data.Dataset: The constructed dataset. - - """ - # Setup - if normalization_range is not None: - assert image_range is not None, ( - "If normalization_range is set, image_range must be set as well." 
- ) - - resize_kwargs = resize_kwargs or {} - - if num_shards > 1: - assert shard_index is not None, "shard_index must be specified" - assert shard_index < num_shards, "shard_index must be less than num_shards" - assert shard_index >= 0, "shard_index must be greater than or equal to 0" - - image_extractor = H5GeneratorTF( - file_paths, - key, - n_frames=n_frames, - frame_index_stride=frame_index_stride, - frame_axis=frame_axis, - insert_frame_axis=insert_frame_axis, - initial_frame_axis=initial_frame_axis, - return_filename=return_filename, - shuffle=shuffle, - sort_files=sort_files, - overlapping_blocks=overlapping_blocks, - limit_n_samples=limit_n_samples, - limit_n_frames=limit_n_frames, - seed=seed, - additional_axes_iter=additional_axes_iter, - cache=cache, - validate=validate, - **kwargs, - ) - - # Create dataset - dataset = tf.data.Dataset.from_generator( - image_extractor, output_signature=image_extractor.output_signature - ) - - # Assert cardinality - dataset = dataset.apply(tf.data.experimental.assert_cardinality(len(image_extractor))) - - # Shard dataset - if num_shards > 1: - dataset = dataset.shard(num_shards, shard_index) - - # Define helper function to apply map function to dataset - def dataset_map(dataset, func): - """Does not apply func to filename.""" - if return_filename: - return dataset.map( - lambda x, filename: (func(x), filename), - num_parallel_calls=tf.data.AUTOTUNE, - ) - else: - return dataset.map(func, num_parallel_calls=tf.data.AUTOTUNE) - - # add channel dim - if len(image_extractor.shape) < 3: - dataset = dataset_map(dataset, lambda x: tf.expand_dims(x, axis=-1)) - - # Clip to image range - if clip_image_range and image_range is not None: - dataset = dataset_map( - dataset, - partial( - tf.clip_by_value, - clip_value_min=image_range[0], - clip_value_max=image_range[1], - ), - ) - - # Check if there are outliers in the image range - if assert_image_range and image_range is not None: - dataset = dataset_map(dataset, 
partial(_assert_image_range, image_range=image_range)) - - if image_size or resize_type: - if frame_axis != -1: - assert resize_axes is not None, ( - "Resizing only works with frame_axis = -1. Alternatively, " - "you can specify resize_axes." - ) - - # Let resizer handle the assertions. - resizer = Resizer( - image_size=image_size, - resize_type=resize_type, - resize_axes=resize_axes, - seed=seed, - **resize_kwargs, - ) - dataset = dataset_map(dataset, resizer) - - # repeat dataset if needed (used for smaller datasets) - if dataset_repetitions: - dataset = dataset.repeat(dataset_repetitions) - - # batch - if batch_size: - dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) - - # normalize - if normalization_range is not None: - dataset = dataset_map( - dataset, - lambda x: translate(x, image_range, normalization_range), - ) - - # augmentation - if augmentation is not None: - dataset = dataset_map(dataset, augmentation) - - # prefetch - if prefetch: - dataset = dataset.prefetch(tf.data.AUTOTUNE) - - if wrap_in_keras: - dataset = TFDatasetToKeras(dataset) - - return dataset diff --git a/zea/backend/tensorflow/utils/callbacks.py b/zea/backend/tensorflow/utils/callbacks.py deleted file mode 100644 index cfe84bb81..000000000 --- a/zea/backend/tensorflow/utils/callbacks.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Useful Tensorflow callbacks - -Not implemented yet. -""" diff --git a/zea/data/__init__.py b/zea/data/__init__.py index fe9963b85..765de857e 100644 --- a/zea/data/__init__.py +++ b/zea/data/__init__.py @@ -47,6 +47,5 @@ generate_zea_dataset, validate_input_data, ) -from .dataloader import H5Generator from .datasets import Dataset, Folder from .file import File, load_file diff --git a/zea/data/dataloader.py b/zea/data/dataloader.py index b96ec74bd..ed749f0bd 100644 --- a/zea/data/dataloader.py +++ b/zea/data/dataloader.py @@ -1,24 +1,41 @@ -""" -H5 dataloader for loading images from zea datasets. +"""H5 dataloader for loading images from zea datasets. 
+ +Example: + .. code-block:: python + + from zea import Dataloader + + loader = Dataloader( + file_paths="/path/to/dataset", + key="data/image", + batch_size=16, + image_range=(-60, 0), + normalization_range=(0, 1), + image_size=(256, 256), + num_threads=16, + ) + + for batch in loader: + # batch is a numpy array of shape (batch_size, 256, 256, 1) + ... """ import re +import threading from itertools import product from pathlib import Path -from typing import List, Tuple, Union +from typing import List +import grain +import keras import numpy as np from zea import log from zea.data.datasets import Dataset, H5FileHandleCache, count_samples_per_directory -from zea.data.file import File -from zea.data.utils import json_dumps -from zea.io_lib import retry_on_io_error -from zea.utils import map_negative_indices +from zea.data.layers import Resizer +from zea.utils import canonicalize_axis, map_negative_indices DEFAULT_NORMALIZATION_RANGE = (0, 1) -MAX_RETRY_ATTEMPTS = 3 -INITIAL_RETRY_DELAY = 0.1 def generate_h5_indices( @@ -65,18 +82,20 @@ def generate_h5_indices( ( "/folder/path_to_file.hdf5", "data/image", - (range(0, 1), slice(None, 256, None), slice(None, 256, None)), + (slice(0, 1, 1), slice(None, 256, None), slice(None, 256, None)), ), ( "/folder/path_to_file.hdf5", "data/image", - (range(1, 2), slice(None, 256, None), slice(None, 256, None)), + (slice(1, 2, 1), slice(None, 256, None), slice(None, 256, None)), ), ..., ] """ - if not limit_n_frames: + if limit_n_frames is None: limit_n_frames = np.inf + else: + assert limit_n_frames > 0, f"limit_n_frames must be > 0, got {limit_n_frames}" assert len(file_paths) == len(file_shapes), "file_paths and file_shapes must have same length" @@ -99,7 +118,7 @@ def generate_h5_indices( file_paths = [file_paths[i] for i in indices_sorting_file_paths] file_shapes = [file_shapes[i] for i in indices_sorting_file_paths] except Exception: - log.warning("H5Generator: Could not sort file_paths by number.") + log.warning("Could not 
sort file_paths by number.") # block size with stride included block_size = n_frames * frame_index_stride @@ -117,7 +136,7 @@ def axis_indices_files(): # Optionally limit frames to load from each file n_frames_in_file = min(n_frames_in_file, limit_n_frames) indices = [ - list(range(i, i + block_size, frame_index_stride)) + slice(i, i + block_size, frame_index_stride) for i in range(0, n_frames_in_file - block_size + 1, block_step_size) ] yield [indices] @@ -144,7 +163,7 @@ def axis_indices_files(): if skipped_files > 0: log.warning( - f"H5Generator: Skipping {skipped_files} files with not enough frames " + f"Skipping {skipped_files} files with not enough frames " f"which is about {skipped_files / len(file_paths) * 100:.2f}% of the " f"dataset. This can be fine if you expect set `n_frames` and " "`frame_index_stride` to be high. Minimum frames in a file needs to be at " @@ -154,263 +173,519 @@ def axis_indices_files(): return indices -def _h5_reopen_on_io_error( - dataloader_obj: H5FileHandleCache, - file, - key, - indices, - retry_count, - **kwargs, -): - """Reopen the file if an I/O error occurs. - Also removes the file from the cache and try to close file. - """ - file_name = indices[0] - try: - file_handle = dataloader_obj._file_handle_cache.pop(file_name, None) - if file_handle is not None: - file_handle.close() - except Exception: - pass - - log.warning( - f"H5Generator: I/O error occurred while reading file {file_name}. " - f"Retry opening file. Retry count: {retry_count}." - ) - - -class H5Generator(Dataset): - """H5Generator class for iterating over hdf5 files in an advanced way. - Mostly used internally, you might want to use the Dataloader class instead. - Loads one item at a time. Always outputs numpy arrays. +class H5DataSource: + """Thread-safe random-access data source for HDF5 files. + + Implements ``grain.RandomAccessDataSource`` protocol (``__getitem__`` + and ``__len__``) so it can be plugged directly into a + ``grain.MapDataset`` pipeline. 
+ + Each worker thread gets its own ``H5FileHandleCache`` via + ``threading.local()`` so ``h5py`` file handles are never shared across + threads. + + Args: + file_paths: Path(s) to HDF5 directory(ies) or file(s). + key: HDF5 dataset key, e.g. ``"data/image"``. + n_frames: Number of consecutive frames per sample. + frame_index_stride: Stride between frames. + frame_axis: Axis along which frames are stacked in the output. + insert_frame_axis: Whether to insert a new axis for frames. + initial_frame_axis: Source axis that stores frames in the file. + additional_axes_iter: Extra axes to iterate over. + sort_files: Sort files numerically. + overlapping_blocks: Allow overlapping frame blocks. + limit_n_samples: Cap the number of samples. + limit_n_frames: Cap frames loaded per file. + return_filename: Return filename metadata with each sample. + cache: Cache loaded samples to RAM. + validate: Validate dataset against the zea format. """ def __init__( self, - file_paths: List[str], + file_paths: List[str] | str, key: str = "data/image", n_frames: int = 1, - shuffle: bool = True, - return_filename: bool = False, - limit_n_samples: int | None = None, - limit_n_frames: int | None = None, - seed: int | None = None, - cache: bool = False, + frame_index_stride: int = 1, + frame_axis: int = -1, + insert_frame_axis: bool = True, + initial_frame_axis: int = 0, additional_axes_iter: tuple | None = None, sort_files: bool = True, overlapping_blocks: bool = False, - initial_frame_axis: int = 0, - insert_frame_axis: bool = True, - frame_index_stride: int = 1, - frame_axis: int = -1, + limit_n_samples: int | None = None, + limit_n_frames: int | None = None, + return_filename: bool = False, + cache: bool = False, validate: bool = True, **kwargs, ): - super().__init__(file_paths, validate=validate, **kwargs) + self.return_filename = return_filename + self.cache = cache + self._data_cache = {} self.key = key self.n_frames = int(n_frames) self.frame_index_stride = int(frame_index_stride) 
self.frame_axis = int(frame_axis) self.insert_frame_axis = insert_frame_axis - self.initial_frame_axis = int(initial_frame_axis) - self.return_filename = return_filename - self.shuffle = shuffle - self.sort_files = sort_files - self.overlapping_blocks = overlapping_blocks - self.limit_n_samples = limit_n_samples - self.limit_n_frames = limit_n_frames - self.seed = seed - self.additional_axes_iter = additional_axes_iter or [] assert self.frame_index_stride > 0, ( - f"`frame_index_stride` must be greater than 0, got {self.frame_index_stride}" + f"`frame_index_stride` must be > 0, got {self.frame_index_stride}" ) - assert self.n_frames > 0, f"`n_frames` must be greater than 0, got {self.n_frames}" + assert self.n_frames > 0, f"`n_frames` must be > 0, got {self.n_frames}" - # Extract some general information about the dataset - file_shapes = self.load_file_shapes(key) - image_shapes = np.array(file_shapes) - image_shapes = np.delete( - image_shapes, (self.initial_frame_axis, *self.additional_axes_iter), axis=1 - ) - n_dims = len(image_shapes[0]) + # Discover files and shapes (reuses Dataset machinery) + _dataset = Dataset(file_paths, validate=validate, **kwargs) + self.file_paths = _dataset.file_paths + self.file_shapes = _dataset.load_file_shapes(key) + _dataset.close() - self.equal_file_shapes = np.all(image_shapes == image_shapes[0]) - if not self.equal_file_shapes: - log.warning( - "H5Generator: Not all files have the same shape. " - "This can lead to issues when resizing images later...." 
- ) - self.shape = np.array([None] * n_dims) - else: - self.shape = np.array(image_shapes[0]) - - if insert_frame_axis: - _frame_axis = map_negative_indices([frame_axis], len(self.shape) + 1) - self.shape = np.insert(self.shape, _frame_axis, 1) - if self.shape[frame_axis]: - self.shape[frame_axis] = self.shape[frame_axis] * n_frames - - # Set random number generator - self.rng = np.random.default_rng(self.seed) + num_dims = len(self.file_shapes[0]) + self.initial_frame_axis = canonicalize_axis(int(initial_frame_axis), num_dims) + self.additional_axes_iter = map_negative_indices(list(additional_axes_iter or []), num_dims) + # Compute per-sample index table self.indices = generate_h5_indices( file_paths=self.file_paths, - file_shapes=file_shapes, + file_shapes=self.file_shapes, n_frames=self.n_frames, frame_index_stride=self.frame_index_stride, key=self.key, initial_frame_axis=self.initial_frame_axis, additional_axes_iter=self.additional_axes_iter, - sort_files=self.sort_files, - overlapping_blocks=self.overlapping_blocks, - limit_n_frames=self.limit_n_frames, + sort_files=sort_files, + overlapping_blocks=overlapping_blocks, + limit_n_frames=limit_n_frames, ) - if not self.shuffle: - log.warning("H5Generator: Not shuffling data.") - - if limit_n_samples: - log.warning( - f"H5Generator: Limiting number of samples to {limit_n_samples} " - f"out of {len(self.indices)}" - ) + if limit_n_samples is not None: + log.info(f"H5DataSource: Limiting to {limit_n_samples} / {len(self.indices)} samples.") self.indices = self.indices[:limit_n_samples] - self.shuffled_items = list(range(len(self.indices))) + # Thread-local file handle caches (one per thread) + self._local = threading.local() + self._all_caches: set[H5FileHandleCache] = set() + self._all_caches_lock = threading.Lock() - # Retry count for I/O errors - self.retry_count = 0 + def __len__(self) -> int: + return len(self.indices) - # Create a cache for the data - self.cache = cache - self._data_cache = {} + def 
__getitem__(self, index: int): + """Return a single sample as a numpy array. Thread-safe.""" + if self.cache and index in self._data_cache: + return self._data_cache[index] + + file_name, key, indices = self.indices[index] + file_handle_cache = self._get_file_handle_cache() + file = file_handle_cache.get_file(file_name) + + try: + images = file.load_data(key, indices) + except (OSError, IOError): + # Invalidate cache entry and retry once + file_handle_cache.pop(file_name) + file = file_handle_cache.get_file(file_name) + images = file.load_data(key, indices) + + if self.insert_frame_axis: + initial = self.initial_frame_axis + if self.additional_axes_iter: + initial -= sum(ax < self.initial_frame_axis for ax in self.additional_axes_iter) + images = np.moveaxis(images, initial, self.frame_axis) + else: + images = np.concatenate(images, axis=self.frame_axis) - def _get_single_item(self, idx): - # Check if the item is already in the cache - if self.cache and idx in self._data_cache: - return self._data_cache[idx] - - # Get the data - file_name, key, indices = self.indices[idx] - file = self.get_file(file_name) - image = self.load(file, key, indices) - file_data = json_dumps( - { - "fullpath": file.filename, + if self.return_filename: + file_data = { + "fullpath": file.filename, # same as file.path, but str type "filename": file.stem, "indices": indices, } - ) + result = (images, file_data) + else: + result = images if self.cache: - # Store the image and file data in the cache - self._data_cache[idx] = [image, file_data] + self._data_cache[index] = result + + return result + + def __repr__(self) -> str: + return ( + f"H5DataSource(n_samples={len(self)}, n_files={len(self.file_paths)}, key='{self.key}')" + ) - return image, file_data + def _get_file_handle_cache(self) -> H5FileHandleCache: + """Return the file-handle cache for the current thread.""" + if not hasattr(self._local, "cache"): + self._local.cache = H5FileHandleCache() + with self._all_caches_lock: + 
self._all_caches.add(self._local.cache) + return self._local.cache - def __getitem__(self, index): - image, file_data = self._get_single_item(self.shuffled_items[index]) + def close(self): + """Close all file handles across all threads.""" + with self._all_caches_lock: + for c in self._all_caches: + c.close() + self._all_caches.clear() - if self.return_filename: - return image, file_data - else: - return image - - @retry_on_io_error( - max_retries=MAX_RETRY_ATTEMPTS, - initial_delay=INITIAL_RETRY_DELAY, - retry_action=_h5_reopen_on_io_error, - ) - def load( + +class Dataloader: + """High-performance HDF5 dataloader built on `Grain <https://github.com/google/grain>`_. + + .. code-block:: text + + grain threads (N) → h5py (thread-local handles) → numpy → user + + The entire pipeline runs in numpy — no framework dependency until + you feed tensors to your model. + + Does the following in order to load a dataset: + + - Find all .hdf5 files in the directory(ies) + - Load the data from each file using the specified key + - Apply the following transformations in order (if specified): + + - shuffle + - shard + - add channel dim + - clip_image_range + - assert_image_range + - resize + - repeat + - batch + - normalize + - augmentation + + Args: + file_paths: Path(s) to directory(ies) and/or HDF5 file(s). + key: HDF5 dataset key. Default is ``"data/image"``. + batch_size: Batch size. Set to ``None`` to disable batching. + Default is ``16``. + n_frames: Number of consecutive frames per sample. Default is ``1``. + When ``n_frames > 1``, frames are grouped into blocks. + shuffle: Shuffle dataset each epoch. Default is ``True``. + return_filename: Return filename metadata together with each sample. + Default is ``False``. + seed: Random seed used for shuffling. Default is ``None``. + If ``None`` and ``shuffle=True``, a random seed is generated. + limit_n_samples: Limit total number of samples (useful for debugging). + Default is ``None`` (no limit).
+ limit_n_frames: Limit frames loaded per file to the first N frames. + Default is ``None`` (no limit). + drop_remainder: Drop the final incomplete batch. Default is ``False``. + image_size: Target ``(height, width)``. Default is ``None`` (no resizing). + resize_type: Resize strategy. One of ``"resize"``, ``"center_crop"``, + ``"random_crop"`` or ``"crop_or_pad"``. Default is ``None``, + which resolves to ``"resize"`` when `image_size` is set. + resize_axes: Axes to resize along, must have length 2 (height, width). + Only needed when data has more than ``(h, w, c)`` dimensions. + Axes are interpreted after frame-axis insertion/reordering. + Default is ``None``. + resize_kwargs: Extra keyword arguments passed to ``Resizer``. + Default is ``None``. + image_range: Source value range of images, e.g. ``(-60, 0)``. + Used for clipping/asserting/normalization. Default is ``None``. + normalization_range: Target value range, e.g. ``(0, 1)``. + If set, ``image_range`` must also be set. Default is ``None``. + clip_image_range: Clip values to ``image_range`` before normalization. + Default is ``False``. + assert_image_range: Assert values stay within ``image_range``. + Default is ``True``. + dataset_repetitions: Repeat dataset this many times. Repetition happens + after sharding. Default is ``None`` (no repetition). + cache: Cache loaded samples in RAM. Default is ``False``. + Note that with ``overlapping_blocks=True``, the same frame can be part of multiple + samples, so caching will consume more memory. + additional_axes_iter: Additional axes to iterate over in addition to + ``initial_frame_axis``. Default is ``None``. + sort_files: Sort files numerically before indexing. Default is ``True``. + overlapping_blocks: If ``True``, frame blocks overlap by ``n_frames - 1``. + Has no effect when ``n_frames == 1``. Default is ``False``. + augmentation: Callable applied to each batch after normalization. + Default is ``None``. 
+ initial_frame_axis: Axis in file data that represents frames. + Default is ``0``. + insert_frame_axis: If ``True``, keep per-frame samples and move/insert + the frame dimension at ``frame_axis``. If ``False``, loaded frames + are concatenated along ``frame_axis``. Default is ``True``. + frame_index_stride: Step between selected frames in a block. + Default is ``1``. + frame_axis: Axis along which frames are stacked/placed in output. + Default is ``-1``. + validate: Validate discovered files against the zea format. + Default is ``True``. + prefetch: Enable Grain prefetching for iteration. Default is ``True``. + shard_index: Shard index to select when ``num_shards > 1``. + Must satisfy ``0 <= shard_index < num_shards``. + num_shards: Total number of shards for distributed loading. + Sharding happens before downstream transforms. Default is ``1``. + num_threads: Number of Grain read threads (``0`` means main thread only). + Default is ``16``. + prefetch_buffer_size: Size of the Grain buffer for reading elements per Python + process (not per thread). Useful when reading from a distributed file + system. Default is ``500``. + + Example: + .. code-block:: python + + loader = Dataloader( + file_paths="/data/camus", + key="data/image_sc", + batch_size=32, + image_range=(-60, 0), + normalization_range=(0, 1), + image_size=(256, 256), + ) + for batch in loader: + ... # batch.shape == (32, 256, 256, 1) + """ + + def __init__( self, - file: File, - key: str, - indices: Tuple[Union[list, slice, int], ...] 
| List[int] | int | None = None, + file_paths: List[str] | str, + key: str = "data/image", + batch_size: int | None = 16, + n_frames: int = 1, + shuffle: bool = True, + return_filename: bool = False, + seed: int | None = None, + limit_n_samples: int | None = None, + limit_n_frames: int | None = None, + drop_remainder: bool = False, + image_size: tuple | None = None, + resize_type: str | None = None, + resize_axes: tuple | None = None, + resize_kwargs: dict | None = None, + image_range: tuple | None = None, + normalization_range: tuple | None = None, + clip_image_range: bool = False, + assert_image_range: bool = True, + dataset_repetitions: int | None = None, + cache: bool = False, + additional_axes_iter: tuple | None = None, + sort_files: bool = True, + overlapping_blocks: bool = False, + augmentation: callable = None, + initial_frame_axis: int = 0, + insert_frame_axis: bool = True, + frame_index_stride: int = 1, + frame_axis: int = -1, + validate: bool = True, + prefetch: bool = True, + shard_index: int | None = None, + num_shards: int = 1, + num_threads: int = 16, + prefetch_buffer_size: int = 500, + **kwargs, ): - """Extract data from hdf5 file. - Args: - file_name (str): name of the file to extract image from. - key (str): key of the hdf5 dataset to grab data from. - indices (tuple): indices to extract image from (tuple of slices) - Returns: - np.ndarray: image extracted from hdf5 file and indexed by indices. - """ - try: - images = file.load_data(key, indices) - except (OSError, IOError): - # Let the decorator handle I/O errors - raise - except Exception as exc: - # For non-I/O errors, provide detailed context - raise ValueError( - f"Could not load image at index {indices} " - f"and file {file.name} of shape {file[key].shape}" - ) from exc + # ── Validation ──────────────────────────────────────────────── + if normalization_range is not None: + assert image_range is not None, ( + "If normalization_range is set, image_range must be set too." 
+ ) + if num_shards > 1: + assert shard_index is not None, "shard_index must be specified" + assert 0 <= shard_index < num_shards - # stack frames along frame_axis - if self.insert_frame_axis: - # move frames axis to self.frame_axis - initial_frame_axis = self.initial_frame_axis - if self.additional_axes_iter: - # offset initial_frame_axis if we have additional axes that are before - # the initial_frame_axis - additional_axes_before = sum( - axis < self.initial_frame_axis for axis in self.additional_axes_iter + resize_kwargs = resize_kwargs or {} + + # ── Store config ────────────────────────────────────────────── + self.batch_size = batch_size + self.return_filename = return_filename + self.num_threads = num_threads + self.prefetch_buffer_size = prefetch_buffer_size + self.prefetch = prefetch + self.shuffle = shuffle + + # Grain requires a concrete seed for shuffle — generate one if needed + if seed is None and shuffle: + seed = int(np.random.default_rng().integers(0, 2**31)) + self.seed = seed + self._rng = np.random.default_rng(seed) + + # ── Data source ─────────────────────────────────────────────── + self.source = H5DataSource( + file_paths=file_paths, + key=key, + n_frames=n_frames, + frame_index_stride=frame_index_stride, + frame_axis=frame_axis, + insert_frame_axis=insert_frame_axis, + initial_frame_axis=initial_frame_axis, + additional_axes_iter=additional_axes_iter, + sort_files=sort_files, + overlapping_blocks=overlapping_blocks, + limit_n_samples=limit_n_samples, + limit_n_frames=limit_n_frames, + return_filename=return_filename, + cache=cache, + validate=validate, + **kwargs, + ) + + # ── Store pipeline config for rebuilding per epoch ──────────── + self._pipeline_cfg = dict( + num_shards=num_shards, + shard_index=shard_index, + clip_image_range=clip_image_range, + assert_image_range=assert_image_range, + image_range=image_range, + normalization_range=normalization_range, + dataset_repetitions=dataset_repetitions, + drop_remainder=drop_remainder, + 
augmentation=augmentation, + resizer=None, + ) + + # Pre-build the resizer (stateless, reusable across epochs) + if image_size or resize_type: + resize_type = resize_type or "resize" + if frame_axis != -1: + assert resize_axes is not None, ( + "Resizing only works with frame_axis = -1. Alternatively, " + "you can specify resize_axes." ) - initial_frame_axis = initial_frame_axis - additional_axes_before + self._pipeline_cfg["resizer"] = Resizer( + image_size=image_size, + resize_type=resize_type, + resize_axes=resize_axes, + seed=seed, + **resize_kwargs, + ) - images = np.moveaxis(images, initial_frame_axis, self.frame_axis) - else: - # append frames to existing axis - images = np.concatenate(images, axis=self.frame_axis) + self._map_dataset = self._build_pipeline(seed) - return images + def _build_pipeline(self, seed: int): + """Build the Grain MapDataset pipeline with the given shuffle seed.""" + cfg = self._pipeline_cfg - def _shuffle(self): - self.rng.shuffle(self.shuffled_items) - log.info("H5Generator: Shuffled data.") + def _ds_map(ds, fn): + if self.return_filename: + return ds.map(lambda item: (fn(item[0]), item[1])) + return ds.map(fn) - def __len__(self): - return len(self.indices) + ds = grain.MapDataset.source(self.source) - def iterator(self): - """Generator that yields images from the hdf5 files.""" if self.shuffle: - self._shuffle() - for idx in range(len(self)): - yield self[idx] + ds = ds.shuffle(seed=seed) - def __iter__(self): - """ - Generator that yields images from the hdf5 files. 
+ if cfg["num_shards"] > 1: + ds = ds[cfg["shard_index"] :: cfg["num_shards"]] + + ds = _ds_map(ds, self._ensure_channel_dim) + + if cfg["clip_image_range"] and cfg["image_range"] is not None: + lo, hi = cfg["image_range"] + ds = _ds_map(ds, lambda x, _lo=lo, _hi=hi: keras.ops.clip(x, _lo, _hi)) + + if cfg["assert_image_range"] and cfg["image_range"] is not None: + _ir = cfg["image_range"] + ds = _ds_map(ds, lambda x, _r=_ir: Dataloader._assert_image_range(x, _r)) + + if cfg["resizer"] is not None: + ds = _ds_map(ds, cfg["resizer"]) + + if cfg["dataset_repetitions"] is not None: + ds = ds.repeat(num_epochs=cfg["dataset_repetitions"]) + + if self.batch_size is not None: + ds = ds.batch(batch_size=self.batch_size, drop_remainder=cfg["drop_remainder"]) + + if cfg["normalization_range"] is not None: + _ir, _nr = cfg["image_range"], cfg["normalization_range"] + ds = _ds_map(ds, lambda x, _a=_ir, _b=_nr: Dataloader._normalize(x, _a, _b)) + + if cfg["augmentation"] is not None: + ds = _ds_map(ds, cfg["augmentation"]) + + return ds + + @property + def dataset(self): + """The underlying ``grain.MapDataset``.""" + return self._map_dataset + + def to_iter_dataset(self): + """Convert to a ``grain.IterDataset`` with prefetching. + + This is called automatically when you iterate, but you can call + it explicitly if you want to hold onto the ``IterDataset`` object. 
""" - return self.iterator() - def __repr__(self): - return ( - f"<{self.__class__.__name__} at 0x{id(self):x}: " - f"{len(self)} batches, n_frames={self.n_frames}, key='{self.key}', " - f"shuffle={self.shuffle}, file_paths={len(self.file_paths)}>" + return self._map_dataset.to_iter_dataset( + grain.ReadOptions( + num_threads=self.num_threads, + prefetch_buffer_size=self.prefetch_buffer_size if self.prefetch else 0, + ) ) - def __str__(self): + def __iter__(self): + # Rebuild the pipeline with a fresh seed so each epoch sees a different order + if self.shuffle: + self._map_dataset = self._build_pipeline(seed=int(self._rng.integers(0, 2**31))) + return iter(self.to_iter_dataset()) + + def __len__(self): + """Number of batches (or samples if unbatched).""" + return len(self._map_dataset) + + def __repr__(self): return ( - f"H5Generator with {len(self)} batches from {len(self.file_paths)} files " - f"(key='{self.key}')" + f"" ) + @staticmethod + def _ensure_channel_dim(image): + """Ensure at least 3-D (H, W, C) so batching produces uniform shapes.""" + if len(keras.ops.shape(image)) < 3: + return keras.ops.expand_dims(image, axis=-1) + return image + + @staticmethod + def _assert_image_range(image, image_range): + """Assert that image values are within the specified range.""" + minval = float(keras.ops.min(image)) + maxval = float(keras.ops.max(image)) + if minval < image_range[0]: + raise ValueError( + f"Image min {minval} is below image_range lower bound {image_range[0]}" + ) + if maxval > image_range[1]: + raise ValueError( + f"Image max {maxval} is above image_range upper bound {image_range[1]}" + ) + return image + + @staticmethod + def _normalize(image, image_range, normalization_range): + """Normalize image from image_range to normalization_range.""" + left_min, left_max = image_range + right_min, right_max = normalization_range + scale = (right_max - right_min) / (left_max - left_min) + offset = right_min - scale * left_min + return 
keras.ops.add(keras.ops.multiply(image, scale), offset) + def summary(self): - """Return a string with dataset statistics and per-directory breakdown.""" - total_samples = len(self.indices) - file_names = [idx[0] for idx in self.indices] - # Try to infer directories from file_names + """Print dataset statistics and per-directory breakdown.""" + src = self.source + total_samples = len(src) + file_names = [idx[0] for idx in src.indices] directories = sorted({str(Path(f).parent) for f in file_names}) samples_per_dir = count_samples_per_directory(file_names, directories) - parts = [f"H5Generator with {total_samples} total samples:"] + parts = [f"Dataloader with {total_samples} total samples:"] for dir_path, count in samples_per_dir.items(): - percentage = (count / total_samples) * 100 if total_samples else 0 - parts.append(f" {dir_path}: {count} samples ({percentage:.1f}%)") + pct = (count / total_samples) * 100 if total_samples else 0 + parts.append(f" {dir_path}: {count} samples ({pct:.1f}%)") print("\n".join(parts)) + + def close(self): + """Release file handles.""" + self.source.close() diff --git a/zea/data/datasets.py b/zea/data/datasets.py index f130d60e3..92978df1b 100644 --- a/zea/data/datasets.py +++ b/zea/data/datasets.py @@ -104,6 +104,15 @@ def get_file(self, file_path) -> File: return self._file_handle_cache[file_path] + def pop(self, file_path): + """Pop a file from the cache and close it.""" + file = self._file_handle_cache.pop(file_path, None) + if file is not None: + try: + file.close() + except Exception: + pass # swallow exceptions during close + def close(self): """Close all cached file handles.""" cache: OrderedDict = getattr(self, "_file_handle_cache", None) diff --git a/zea/data/file.py b/zea/data/file.py index abc2d4241..f0f3936ba 100644 --- a/zea/data/file.py +++ b/zea/data/file.py @@ -31,7 +31,7 @@ def assert_key(file: h5py.File, key: str): class File(h5py.File): """h5py.File in zea format.""" - def __init__(self, name, *args, **kwargs): + 
def __init__(self, name, mode="r", *args, **kwargs): """Initialize the file. Args: @@ -39,6 +39,7 @@ def __init__(self, name, *args, **kwargs): Can be a string or a Path object. Additionally can be a string with the prefix 'hf://', in which case it will be resolved to a huggingface path. + mode (str, optional): The mode to open the file in. Defaults to "r". *args: Additional arguments to pass to h5py.File. **kwargs: Additional keyword arguments to pass to h5py.File. """ @@ -48,12 +49,12 @@ def __init__(self, name, *args, **kwargs): name = _hf_resolve_path(str(name)) # Disable locking for read mode by default - if "locking" not in kwargs and "mode" in kwargs and kwargs["mode"] == "r": + if "locking" not in kwargs and mode == "r": # If the file is opened in read mode, disable locking kwargs["locking"] = False # Initialize the h5py.File - super().__init__(name, *args, **kwargs) + super().__init__(name, mode, *args, **kwargs) @property def path(self): diff --git a/zea/data/utils.py b/zea/data/utils.py deleted file mode 100644 index 1ca25c8ce..000000000 --- a/zea/data/utils.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Utility functions for zea datasets.""" - -import json -from pathlib import Path - -from keras import ops - - -class ZeaJSONEncoder(json.JSONEncoder): - """Wrapper for json.dumps to encode range and slice objects. - - Example: - >>> import json - >>> from zea.data.utils import ZeaJSONEncoder - >>> json.dumps(range(10), cls=ZeaJSONEncoder) - '{"__type__": "range", "start": 0, "stop": 10, "step": 1}' - - Note: - Probably you would use the `zea.data.dataloader.json_dumps()` - function instead of using this class directly. 
- """ - - def default(self, o): - if isinstance(o, range): - return { - "__type__": "range", - "start": o.start, - "stop": o.stop, - "step": o.step, - } - if isinstance(o, slice): - return { - "__type__": "slice", - "start": o.start, - "stop": o.stop, - "step": o.step, - } - if isinstance(o, Path): - return str(o) - return super().default(o) - - -def json_dumps(obj): - """Used to serialize objects that contain range and slice objects. - Args: - obj: object to serialize (most likely a dictionary). - Returns: - str: serialized object (json string). - """ - return json.dumps(obj, cls=ZeaJSONEncoder) - - -def json_loads(obj): - """Used to deserialize objects that contain range and slice objects. - Args: - obj: object to deserialize (most likely a json string). - Returns: - object: deserialized object (dictionary). - """ - return json.loads(obj, object_hook=_zea_datasets_json_decoder) - - -def decode_file_info(file_info): - """Decode file info from a json string. - A batch of H5Generator can return a list of file_info that are json strings. 
- This function decodes the json strings and returns a list of dictionaries - with the information, namely: - - full_path: full path to the file - - file_name: file name - - indices: indices used to extract the image from the file - """ - - if file_info.ndim == 0: - file_info = [file_info] - - decoded_info = [] - for info in file_info: - info = ops.convert_to_numpy(info)[()].decode("utf-8") - decoded_info.append(json_loads(info)) - return decoded_info - - -def _zea_datasets_json_decoder(dct): - """Wrapper for json.loads to decode range and slice objects.""" - if "__type__" in dct: - if dct["__type__"] == "range": - return range(dct["start"], dct["stop"], dct["step"]) - if dct["__type__"] == "slice": - return slice(dct["start"], dct["stop"], dct["step"]) - return dct diff --git a/zea/internal/utils.py b/zea/internal/utils.py index 3efbd2f28..67fb4c42f 100644 --- a/zea/internal/utils.py +++ b/zea/internal/utils.py @@ -216,37 +216,6 @@ def fn_requires_argument(fn, arg_name): return arg_name in params -def find_methods_with_return_type(cls, return_type_hint: str): - """ - Find all methods in a class that have the specified return type hint. - - Args: - cls: The class to inspect. - return_type_hint (str): The return type hint to match. - - Returns: - A list of method names that match the return type hint. 
- """ - matching_methods = [] - for name, member in inspect.getmembers(cls, predicate=inspect.isfunction): - annotations = getattr(member, "__annotations__", {}) - return_annotation = annotations.get("return") - if return_annotation is None: - continue - - # Convert annotation to string for comparison - if hasattr(return_annotation, "__name__"): - # For types like bool, int, str, custom classes - annotation_str = return_annotation.__name__ - else: - # For string annotations or other types, convert to string - annotation_str = str(return_annotation) - - if annotation_str == return_type_hint: - matching_methods.append(name) - return matching_methods - - def keep_trying(fn, args=None, required_set=None): """Keep trying to run a function until it succeeds. diff --git a/zea/io_lib.py b/zea/io_lib.py index 15e262dce..2de26c503 100644 --- a/zea/io_lib.py +++ b/zea/io_lib.py @@ -3,9 +3,7 @@ Use to quickly read and write files or interact with file system. """ -import functools import os -import time from io import BytesIO from pathlib import Path from typing import Generator @@ -378,65 +376,6 @@ def matplotlib_figure_to_numpy(fig, **kwargs): return image -def retry_on_io_error(max_retries=3, initial_delay=0.5, retry_action=None): - """Decorator to retry functions on I/O errors with exponential backoff. - - Args: - max_retries (int): Maximum number of retry attempts. - initial_delay (float): Initial delay between retries in seconds. - retry_action (callable, optional): Optional function to call before each retry attempt. - If decorating a method: ``retry_action(self, exception, attempt, *args, **kwargs)`` - If decorating a function: ``retry_action(exception, attempt, *args, **kwargs)`` - - Returns: - callable: Decorated function with retry logic. 
- - """ - - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - delay = initial_delay - last_exception = None - - for attempt in range(max_retries): - try: - return func(*args, **kwargs) - except (OSError, IOError) as e: - last_exception = e - - # if args exist and first arg is a class, update retry count of that method - if args and hasattr(args[0], "retry_count"): - args[0].retry_count = attempt + 1 - - if attempt < max_retries - 1: - # Execute custom retry action if provided - if retry_action: - # Pass all original arguments to retry_action - retry_action( - *args, - exception=e, - retry_count=attempt, - **kwargs, - ) - - time.sleep(delay) - - else: - # Last attempt failed - log.error(f"Failed after {max_retries} attempts: {e}") - - # If we've exhausted all retries - raise ValueError( - f"Failed to complete operation after {max_retries} attempts. " - f"Last error: {last_exception}" - ) - - return wrapper - - return decorator - - def _convert_image_mode(img, mode="L"): """Convert a PIL Image to the specified mode and return as numpy array.""" if mode not in {"L", "RGB"}: