Skip to content

Commit 0a99daf

Browse files
committed
Merge remote-tracking branch 'origin/main' into feature/joss-paper
2 parents f9ec34e + 3d06e3d commit 0a99daf

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+4768
-1109
lines changed

.github/workflows/tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ jobs:
158158
- name: Set up Git
159159
run: git config --global safe.directory $GITHUB_WORKSPACE
160160

161-
- name: Build documentation with Sphinx (fail on warnings)
161+
- name: Build documentation with Sphinx (fails on warnings)
162162
run: |
163163
docker run --rm \
164164
-v "${{ github.workspace }}:/zea" \

docs/source/environment.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ Here are the environment variables that ``zea`` uses at runtime. Arguably the mo
3131
- Timeout in seconds for calling ``nvidia-smi`` to get GPU information during :func:`zea.init_device`.
3232
- ``30``
3333
- Any positive integer, or ``<= 0`` to disable timeout.
34+
* - ``ZEA_DOWNLOAD_TIMEOUT``
35+
- Timeout in seconds for downloading files, e.g. during dataset conversion.
36+
- ``60``
37+
- Any positive integer, or ``<= 0`` to disable timeout.
3438
* - ``ZEA_FIND_H5_SHAPES_PARALLEL``
3539
- If set to ``1``, will use parallel processing when searching for HDF5 file shapes.
3640
- ``1``

docs/source/getting-started.rst

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ Let's take a quick look at how to use ``zea`` to load and process ultrasound dat
1616
zea.init_device()
1717
1818
# loading a config file from Hugging Face, but can also load a local config file
19-
config = zea.Config.from_hf(
20-
"zeahub/configs", "config_picmus_rf.yaml", repo_type="dataset",
21-
)
19+
config = zea.Config.from_path("hf://zeahub/configs/config_picmus_rf.yaml")
2220
2321
path = config.data.dataset_folder + "/" + config.data.file_path
2422
with zea.File(path) as file:
@@ -50,7 +48,7 @@ Similarly, we can easily load one of the pretrained models from the :mod:`zea.mo
5048
model = EchoNetDynamic.from_preset("echonet-dynamic")
5149
5250
# we'll load a single file from the dataset
53-
with zea.Dataset("hf://zeahub/camus-sample/", "image_sc") as dataset:
51+
with zea.Dataset("hf://zeahub/camus-sample/") as dataset:
5452
file = dataset[0]
5553
image = file.load_data("image_sc", indices=0)
5654

docs/source/notebooks/agent/agent_example.ipynb

Lines changed: 67 additions & 7 deletions
Large diffs are not rendered by default.
38.5 KB
Loading

docs/source/notebooks/data/zea_data_example.ipynb

Lines changed: 30 additions & 28 deletions
Large diffs are not rendered by default.

docs/source/notebooks/data/zea_simulation_example.ipynb

Lines changed: 114 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -266,23 +266,25 @@
266266
"positions = phantoms.fish()\n",
267267
"magnitudes = np.ones(len(positions), dtype=np.float32)\n",
268268
"\n",
269-
"rf_data = simulate_rf(\n",
270-
" scatterer_positions=positions,\n",
271-
" scatterer_magnitudes=magnitudes,\n",
272-
" probe_geometry=probe.probe_geometry,\n",
273-
" apply_lens_correction=scan.apply_lens_correction,\n",
274-
" lens_thickness=scan.lens_thickness,\n",
275-
" lens_sound_speed=scan.lens_sound_speed,\n",
276-
" sound_speed=scan.sound_speed,\n",
277-
" n_ax=scan.n_ax,\n",
278-
" center_frequency=probe.center_frequency,\n",
279-
" sampling_frequency=probe.sampling_frequency,\n",
280-
" t0_delays=scan.t0_delays,\n",
281-
" initial_times=scan.initial_times,\n",
282-
" element_width=scan.element_width,\n",
283-
" attenuation_coef=scan.attenuation_coef,\n",
284-
" tx_apodizations=scan.tx_apodizations,\n",
285-
")\n",
269+
"simulation_args = {\n",
270+
" \"scatterer_positions\": positions,\n",
271+
" \"scatterer_magnitudes\": magnitudes,\n",
272+
" \"probe_geometry\": probe.probe_geometry,\n",
273+
" \"apply_lens_correction\": scan.apply_lens_correction,\n",
274+
" \"lens_thickness\": scan.lens_thickness,\n",
275+
" \"lens_sound_speed\": scan.lens_sound_speed,\n",
276+
" \"sound_speed\": scan.sound_speed,\n",
277+
" \"n_ax\": scan.n_ax,\n",
278+
" \"center_frequency\": probe.center_frequency,\n",
279+
" \"sampling_frequency\": probe.sampling_frequency,\n",
280+
" \"t0_delays\": scan.t0_delays,\n",
281+
" \"initial_times\": scan.initial_times,\n",
282+
" \"element_width\": scan.element_width,\n",
283+
" \"attenuation_coef\": scan.attenuation_coef,\n",
284+
" \"tx_apodizations\": scan.tx_apodizations,\n",
285+
"}\n",
286+
"\n",
287+
"rf_data = simulate_rf(**simulation_args)\n",
286288
"print(\"Simulated RF data shape:\", rf_data.shape)"
287289
]
288290
},
@@ -334,9 +336,9 @@
334336
"name": "stdout",
335337
"output_type": "stream",
336338
"text": [
337-
"\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[38;5;214mWARNING\u001b[0m No transmit origins provided, using zeros\n",
338339
"\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[38;5;214mWARNING\u001b[0m No azimuth angles provided, using zeros\n",
339-
"\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[33mDEBUG\u001b[0m [zea.Pipeline] The following input keys are not used by the pipeline: {'center_frequency', 'n_el', 'zlims', 'xlims'}. Make sure this is intended. This warning will only be shown once.\n"
340+
"\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[38;5;214mWARNING\u001b[0m No transmit origins provided, using zeros\n",
341+
"\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[33mDEBUG\u001b[0m [zea.Pipeline] The following input keys are not used by the pipeline: {'center_frequency', 'xlims', 'n_el', 'zlims'}. Make sure this is intended. This warning will only be shown once.\n"
340342
]
341343
}
342344
],
@@ -384,6 +386,99 @@
384386
"source": [
385387
"That's it! You have now simulated ultrasound RF data and reconstructed a B-mode image using `zea`."
386388
]
389+
},
390+
{
391+
"cell_type": "markdown",
392+
"id": "c6dc06de",
393+
"metadata": {},
394+
"source": [
395+
"## Speedup with Just-In-Time compilation (JIT)\n",
396+
"\n",
397+
"The `simulate_rf` function took quite some time to compute in this example. Larger experiments with more point scatterers or acquisitions can execute very slowly. In this case, it is advised to [JIT-compile](https://docs.jax.dev/en/latest/jit-compilation.html) the `simulate_rf` function. The way you do this depends on which machine learning backend (e.g., JAX, PyTorch, TensorFlow) you are using (see [documentation](../../installation.rst#Backend) for details). Starting with JAX, you can simply wrap the function with `jax.jit` as follows:\n",
398+
"\n",
399+
"**JAX**"
400+
]
401+
},
402+
{
403+
"cell_type": "code",
404+
"execution_count": 12,
405+
"id": "8aa8873d",
406+
"metadata": {},
407+
"outputs": [],
408+
"source": [
409+
"from jax import jit\n",
410+
"\n",
411+
"simulate_rf_jit = jit(simulate_rf, static_argnames=[\"apply_lens_correction\", \"n_ax\"])"
412+
]
413+
},
414+
{
415+
"cell_type": "markdown",
416+
"id": "89b0adf8",
417+
"metadata": {},
418+
"source": [
419+
"Let's execute and time the JIT versus non-JIT versions of the `simulate_rf` function to see the speedup."
420+
]
421+
},
422+
{
423+
"cell_type": "code",
424+
"execution_count": 13,
425+
"id": "83453631",
426+
"metadata": {},
427+
"outputs": [
428+
{
429+
"name": "stdout",
430+
"output_type": "stream",
431+
"text": [
432+
"\u001b[1mFunction Timing Statistics\u001b[0m\n",
433+
"=====================================================================================================\n",
434+
"\u001b[36mFunction\u001b[0m \u001b[32mMean\u001b[0m \u001b[32mMedian\u001b[0m \u001b[32mStd Dev\u001b[0m \u001b[33mMin\u001b[0m \u001b[33mMax\u001b[0m \u001b[35mCount\u001b[0m \n",
435+
"-----------------------------------------------------------------------------------------------------\n",
436+
"\u001b[36msimulate_rf\u001b[0m \u001b[32m0.220066\u001b[0m \u001b[32m0.218923\u001b[0m \u001b[32m0.021850\u001b[0m \u001b[33m0.189399\u001b[0m \u001b[33m0.255911\u001b[0m \u001b[35m30\u001b[0m \n",
437+
"\u001b[36msimulate_rf (JIT)\u001b[0m \u001b[32m0.004081\u001b[0m \u001b[32m0.003444\u001b[0m \u001b[32m0.003207\u001b[0m \u001b[33m0.003159\u001b[0m \u001b[33m0.020947\u001b[0m \u001b[35m30\u001b[0m \n"
438+
]
439+
}
440+
],
441+
"source": [
442+
"from zea.utils import FunctionTimer\n",
443+
"\n",
444+
"# Warm-up JIT compilation before benchmarking\n",
445+
"simulate_rf_jit(**simulation_args)\n",
446+
"\n",
447+
"timer = FunctionTimer()\n",
448+
"timed_rf = timer(simulate_rf, name=\"simulate_rf\")\n",
449+
"timed_rf_jit = timer(simulate_rf_jit, name=\"simulate_rf (JIT)\")\n",
450+
"\n",
451+
"for _ in range(30):\n",
452+
" timed_rf_jit(**simulation_args)\n",
453+
" timed_rf(**simulation_args)\n",
454+
"\n",
455+
"timer.print()"
456+
]
457+
},
458+
{
459+
"cell_type": "markdown",
460+
"id": "72b4387a",
461+
"metadata": {},
462+
"source": [
463+
"If you are using another backend, a similar approach can be taken:"
464+
]
465+
},
466+
{
467+
"cell_type": "markdown",
468+
"id": "4eba2540",
469+
"metadata": {},
470+
"source": [
471+
"**PyTorch**\n",
472+
"```python\n",
473+
"import torch\n",
474+
"simulate_rf_jit = torch.jit.script(simulate_rf)\n",
475+
"```\n",
476+
"**TensorFlow**\n",
477+
"```python\n",
478+
"import tensorflow as tf\n",
479+
"simulate_rf_jit = tf.function(simulate_rf)\n",
480+
"```"
481+
]
387482
}
388483
],
389484
"metadata": {

docs/source/notebooks/models/diffusion_model_example.ipynb

Lines changed: 9 additions & 16 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)