|
266 | 266 | "positions = phantoms.fish()\n", |
267 | 267 | "magnitudes = np.ones(len(positions), dtype=np.float32)\n", |
268 | 268 | "\n", |
269 | | - "rf_data = simulate_rf(\n", |
270 | | - " scatterer_positions=positions,\n", |
271 | | - " scatterer_magnitudes=magnitudes,\n", |
272 | | - " probe_geometry=probe.probe_geometry,\n", |
273 | | - " apply_lens_correction=scan.apply_lens_correction,\n", |
274 | | - " lens_thickness=scan.lens_thickness,\n", |
275 | | - " lens_sound_speed=scan.lens_sound_speed,\n", |
276 | | - " sound_speed=scan.sound_speed,\n", |
277 | | - " n_ax=scan.n_ax,\n", |
278 | | - " center_frequency=probe.center_frequency,\n", |
279 | | - " sampling_frequency=probe.sampling_frequency,\n", |
280 | | - " t0_delays=scan.t0_delays,\n", |
281 | | - " initial_times=scan.initial_times,\n", |
282 | | - " element_width=scan.element_width,\n", |
283 | | - " attenuation_coef=scan.attenuation_coef,\n", |
284 | | - " tx_apodizations=scan.tx_apodizations,\n", |
285 | | - ")\n", |
| 269 | + "simulation_args = {\n", |
| 270 | + " \"scatterer_positions\": positions,\n", |
| 271 | + " \"scatterer_magnitudes\": magnitudes,\n", |
| 272 | + " \"probe_geometry\": probe.probe_geometry,\n", |
| 273 | + " \"apply_lens_correction\": scan.apply_lens_correction,\n", |
| 274 | + " \"lens_thickness\": scan.lens_thickness,\n", |
| 275 | + " \"lens_sound_speed\": scan.lens_sound_speed,\n", |
| 276 | + " \"sound_speed\": scan.sound_speed,\n", |
| 277 | + " \"n_ax\": scan.n_ax,\n", |
| 278 | + " \"center_frequency\": probe.center_frequency,\n", |
| 279 | + " \"sampling_frequency\": probe.sampling_frequency,\n", |
| 280 | + " \"t0_delays\": scan.t0_delays,\n", |
| 281 | + " \"initial_times\": scan.initial_times,\n", |
| 282 | + " \"element_width\": scan.element_width,\n", |
| 283 | + " \"attenuation_coef\": scan.attenuation_coef,\n", |
| 284 | + " \"tx_apodizations\": scan.tx_apodizations,\n", |
| 285 | + "}\n", |
| 286 | + "\n", |
| 287 | + "rf_data = simulate_rf(**simulation_args)\n", |
286 | 288 | "print(\"Simulated RF data shape:\", rf_data.shape)" |
287 | 289 | ] |
288 | 290 | }, |
|
334 | 336 | "name": "stdout", |
335 | 337 | "output_type": "stream", |
336 | 338 | "text": [ |
337 | | - "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[38;5;214mWARNING\u001b[0m No transmit origins provided, using zeros\n", |
338 | 339 | "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[38;5;214mWARNING\u001b[0m No azimuth angles provided, using zeros\n", |
339 | | - "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[33mDEBUG\u001b[0m [zea.Pipeline] The following input keys are not used by the pipeline: {'center_frequency', 'n_el', 'zlims', 'xlims'}. Make sure this is intended. This warning will only be shown once.\n" |
| 340 | + "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[38;5;214mWARNING\u001b[0m No transmit origins provided, using zeros\n", |
| 341 | + "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[33mDEBUG\u001b[0m [zea.Pipeline] The following input keys are not used by the pipeline: {'center_frequency', 'xlims', 'n_el', 'zlims'}. Make sure this is intended. This warning will only be shown once.\n" |
340 | 342 | ] |
341 | 343 | } |
342 | 344 | ], |
|
384 | 386 | "source": [ |
385 | 387 | "That's it! You have now simulated ultrasound RF data and reconstructed a B-mode image using `zea`." |
386 | 388 | ] |
| 389 | + }, |
| 390 | + { |
| 391 | + "cell_type": "markdown", |
| 392 | + "id": "c6dc06de", |
| 393 | + "metadata": {}, |
| 394 | + "source": [ |
| 395 | + "## Speedup with Just-In-Time compilation (JIT)\n", |
| 396 | + "\n", |
| 397 | + "The `simulate_rf` function took quite some time to compute in this example. Larger experiments with more point scatterers or acquisitions can execute very slowly. In this case, it is advised to [JIT-compile](https://docs.jax.dev/en/latest/jit-compilation.html) the `simulate_rf` function. The way you do this depends on which machine learning backend (e.g., JAX, PyTorch, TensorFlow) you are using (see [documentation](../../installation.rst#Backend) for details). Starting with JAX, you can simply wrap the function with `jax.jit` as follows:\n", |
| 398 | + "\n", |
| 399 | + "**JAX**" |
| 400 | + ] |
| 401 | + }, |
| 402 | + { |
| 403 | + "cell_type": "code", |
| 404 | + "execution_count": 12, |
| 405 | + "id": "8aa8873d", |
| 406 | + "metadata": {}, |
| 407 | + "outputs": [], |
| 408 | + "source": [ |
| 409 | + "from jax import jit\n", |
| 410 | + "\n", |
| 411 | + "simulate_rf_jit = jit(simulate_rf, static_argnames=[\"apply_lens_correction\", \"n_ax\"])" |
| 412 | + ] |
| 413 | + }, |
| 414 | + { |
| 415 | + "cell_type": "markdown", |
| 416 | + "id": "89b0adf8", |
| 417 | + "metadata": {}, |
| 418 | + "source": [ |
| 419 | + "Let's execute and time the JIT versus non-JIT versions of the `simulate_rf` function to see the speedup." |
| 420 | + ] |
| 421 | + }, |
| 422 | + { |
| 423 | + "cell_type": "code", |
| 424 | + "execution_count": 13, |
| 425 | + "id": "83453631", |
| 426 | + "metadata": {}, |
| 427 | + "outputs": [ |
| 428 | + { |
| 429 | + "name": "stdout", |
| 430 | + "output_type": "stream", |
| 431 | + "text": [ |
| 432 | + "\u001b[1mFunction Timing Statistics\u001b[0m\n", |
| 433 | + "=====================================================================================================\n", |
| 434 | + "\u001b[36mFunction\u001b[0m \u001b[32mMean\u001b[0m \u001b[32mMedian\u001b[0m \u001b[32mStd Dev\u001b[0m \u001b[33mMin\u001b[0m \u001b[33mMax\u001b[0m \u001b[35mCount\u001b[0m \n", |
| 435 | + "-----------------------------------------------------------------------------------------------------\n", |
| 436 | + "\u001b[36msimulate_rf\u001b[0m \u001b[32m0.220066\u001b[0m \u001b[32m0.218923\u001b[0m \u001b[32m0.021850\u001b[0m \u001b[33m0.189399\u001b[0m \u001b[33m0.255911\u001b[0m \u001b[35m30\u001b[0m \n", |
| 437 | + "\u001b[36msimulate_rf (JIT)\u001b[0m \u001b[32m0.004081\u001b[0m \u001b[32m0.003444\u001b[0m \u001b[32m0.003207\u001b[0m \u001b[33m0.003159\u001b[0m \u001b[33m0.020947\u001b[0m \u001b[35m30\u001b[0m \n" |
| 438 | + ] |
| 439 | + } |
| 440 | + ], |
| 441 | + "source": [ |
| 442 | + "from zea.utils import FunctionTimer\n", |
| 443 | + "\n", |
| 444 | + "# Warm-up JIT compilation before benchmarking\n", |
| 445 | + "simulate_rf_jit(**simulation_args)\n", |
| 446 | + "\n", |
| 447 | + "timer = FunctionTimer()\n", |
| 448 | + "timed_rf = timer(simulate_rf, name=\"simulate_rf\")\n", |
| 449 | + "timed_rf_jit = timer(simulate_rf_jit, name=\"simulate_rf (JIT)\")\n", |
| 450 | + "\n", |
| 451 | + "for _ in range(30):\n", |
| 452 | + " timed_rf_jit(**simulation_args)\n", |
| 453 | + " timed_rf(**simulation_args)\n", |
| 454 | + "\n", |
| 455 | + "timer.print()" |
| 456 | + ] |
| 457 | + }, |
| 458 | + { |
| 459 | + "cell_type": "markdown", |
| 460 | + "id": "72b4387a", |
| 461 | + "metadata": {}, |
| 462 | + "source": [ |
| 463 | + "If you are using another backend, a similar approach can be taken:" |
| 464 | + ] |
| 465 | + }, |
| 466 | + { |
| 467 | + "cell_type": "markdown", |
| 468 | + "id": "4eba2540", |
| 469 | + "metadata": {}, |
| 470 | + "source": [ |
| 471 | + "**PyTorch**\n", |
| 472 | + "```python\n", |
| 473 | + "import torch\n", |
| 474 | + "simulate_rf_jit = torch.jit.script(simulate_rf)\n", |
| 475 | + "```\n", |
| 476 | + "**TensorFlow**\n", |
| 477 | + "```python\n", |
| 478 | + "import tensorflow as tf\n", |
| 479 | + "simulate_rf_jit = tf.function(simulate_rf)\n", |
| 480 | + "```" |
| 481 | + ] |
387 | 482 | } |
388 | 483 | ], |
389 | 484 | "metadata": { |
|
0 commit comments