Skip to content

Commit cdc7a5b

Browse files
authored
Accelerated optika.materials.snells_law() using Numba. (#147)
1 parent 63b74a2 commit cdc7a5b

File tree

5 files changed

+95
-33
lines changed

5 files changed

+95
-33
lines changed

optika/materials/_snells_law.py

Lines changed: 95 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from __future__ import annotations
1+
import math
22
import numpy as np
3+
import numba as nb
34
import astropy.units as u
45
import named_arrays as na
56

@@ -38,20 +39,17 @@ def snells_law_scalar(
3839

3940

4041
def snells_law(
41-
wavelength: u.Quantity | na.AbstractScalar,
4242
direction: na.AbstractCartesian3dVectorArray,
4343
index_refraction: float | na.AbstractScalar,
4444
index_refraction_new: float | na.AbstractScalar,
45-
normal: None | na.AbstractCartesian3dVectorArray,
45+
normal: None | na.AbstractCartesian3dVectorArray = None,
4646
is_mirror: bool | na.AbstractScalar = False,
4747
) -> na.Cartesian3dVectorArray:
4848
r"""
4949
A `vector form of Snell's law <https://en.wikipedia.org/wiki/Snell%27s_law#Vector_form>`_.
5050
5151
Parameters
5252
----------
53-
wavelength
54-
The wavelength of the incoming light
5553
direction
5654
The propagation direction of the incoming light
5755
index_refraction
@@ -90,7 +88,6 @@ def snells_law(
9088
# Define the keyword arguments that are common
9189
# to both the reflected and transmitted ray
9290
kwargs = dict(
93-
wavelength=350 * u.nm,
9491
direction=direction,
9592
index_refraction=1,
9693
normal=na.Cartesian3dVectorArray(0, 0, 1),
@@ -271,22 +268,99 @@ def snells_law(
271268
\pm \sqrt{\left( n_2 / n_1 \right)^2 + (\mathbf{k}_1 \cdot \hat{\mathbf{n}})^2 - k_1^2 }
272269
\right) \hat{\mathbf{n}} \right]}
273270
"""
274-
a = direction
275-
n1 = index_refraction # noqa: F841
276-
n2 = index_refraction_new # noqa: F841
277-
278271
if normal is None:
279272
normal = na.Cartesian3dVectorArray(0, 0, -1)
280273

281-
a_x = a.x # noqa: F841
282-
a_y = a.y # noqa: F841
283-
a_z = a.z # noqa: F841
284-
u_x = normal.x # noqa: F841
285-
u_y = normal.y # noqa: F841
286-
u_z = normal.z # noqa: F841
287-
288-
return na.numexpr.evaluate(
289-
"(n1 / n2) * (a + (-(a_x*u_x + a_y*u_y + a_z*u_z) + sign(-(a_x*u_x + a_y*u_y + a_z*u_z)) "
290-
"* (2 * is_mirror - 1) * sqrt(1 / (n1 / n2)**2 + (a_x*u_x + a_y*u_y + a_z*u_z)**2"
291-
"- (a_x*a_x + a_y*a_y + a_z*a_z))) * normal)"
274+
direction = direction << u.dimensionless_unscaled
275+
index_refraction = index_refraction << u.dimensionless_unscaled
276+
index_refraction_new = index_refraction_new << u.dimensionless_unscaled
277+
normal = normal << u.dimensionless_unscaled
278+
279+
b_x, b_y, b_z = _snells_law_numba(
280+
direction.x.value,
281+
direction.y.value,
282+
direction.z.value,
283+
index_refraction.value,
284+
index_refraction_new.value,
285+
normal.x.value,
286+
normal.y.value,
287+
normal.z.value,
288+
is_mirror,
292289
)
290+
291+
return na.Cartesian3dVectorArray(b_x, b_y, b_z)
292+
293+
294+
@nb.guvectorize(
295+
[
296+
"void(float64,float64,float64,float64,float64,float64,float64,float64,bool,float64[:],float64[:],float64[:])"
297+
],
298+
"(),(),(),(),(),(),(),(),()->(),(),()",
299+
target="parallel",
300+
nopython=True,
301+
cache=True,
302+
)
303+
def _snells_law_numba(
304+
direction_x: float,
305+
direction_y: float,
306+
direction_z: float,
307+
index_refraction: float,
308+
index_refraction_new: float,
309+
normal_x: float,
310+
normal_y: float,
311+
normal_z: float,
312+
is_mirror: bool,
313+
result_x: np.ndarray,
314+
result_y: np.ndarray,
315+
result_z: np.ndarray,
316+
): # pragma: nocover
317+
"""
318+
A :mod:`numba`-accelerated version of Snell's law.
319+
320+
Parameters
321+
----------
322+
direction_x
323+
The :math:`x` component of the propagation direction of the incident light.
324+
direction_y
325+
The :math:`y` component of the propagation direction of the incident light.
326+
direction_z
327+
The :math:`z` component of the propagation direction of the incident light.
328+
index_refraction
329+
The index of refraction of the current medium.
330+
index_refraction_new
331+
The index of refraction of the new medium.
332+
normal_x
333+
The :math:`x` component of the vector perpendicular to the interface.
334+
normal_y
335+
The :math:`y` component of the vector perpendicular to the interface.
336+
normal_z
337+
The :math:`z` component of the vector perpendicular to the interface.
338+
is_mirror
339+
Whether the incident light is reflected or not.
340+
"""
341+
a_x = direction_x
342+
a_y = direction_y
343+
a_z = direction_z
344+
345+
n1 = index_refraction
346+
n2 = index_refraction_new
347+
348+
u_x = normal_x
349+
u_y = normal_y
350+
u_z = normal_z
351+
352+
a2 = a_x * a_x + a_y * a_y + a_z * a_z
353+
354+
r = n1 / n2
355+
r2 = r * r
356+
357+
au = a_x * u_x + a_y * u_y + a_z * u_z
358+
au2 = au * au
359+
360+
sgn = -math.copysign(1, au)
361+
362+
d = -au + sgn * (2 * is_mirror - 1) * math.sqrt(1 / r2 + au2 - a2)
363+
364+
result_x[:] = r * (a_x + d * u_x)
365+
result_y[:] = r * (a_y + d * u_y)
366+
result_z[:] = r * (a_z + d * u_z)

optika/materials/_tests/test_snells_law.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,6 @@ def test_snells_law_scalar(
4646
assert np.allclose(result, result_expected)
4747

4848

49-
@pytest.mark.parametrize(
50-
argnames="wavelength",
51-
argvalues=[
52-
350 * u.nm,
53-
na.linspace(300 * u.nm, 400 * u.nm, axis="wavelength", num=3),
54-
],
55-
)
5649
@pytest.mark.parametrize(
5750
argnames="direction",
5851
argvalues=[
@@ -87,15 +80,13 @@ def test_snells_law_scalar(
8780
],
8881
)
8982
def test_snells_law(
90-
wavelength: u.Quantity | na.AbstractScalar,
9183
direction: na.AbstractCartesian3dVectorArray,
9284
index_refraction: float | na.AbstractScalar,
9385
index_refraction_new: float | na.AbstractScalar,
9486
normal: None | na.AbstractCartesian3dVectorArray,
9587
is_mirror: bool | na.AbstractScalar,
9688
):
9789
result = optika.materials.snells_law(
98-
wavelength=wavelength,
9990
direction=direction,
10091
index_refraction=index_refraction,
10192
index_refraction_new=index_refraction_new,

optika/materials/matrices.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,6 @@ def transfer(
320320
"""
321321

322322
direction_internal = snells_law(
323-
wavelength=wavelength,
324323
direction=direction,
325324
index_refraction=1,
326325
index_refraction_new=np.real(n),

optika/rulings/_spacing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ class HolographicRulingSpacing(
191191
192192
# Compute the output direction of the diffracted rays.
193193
direction_output = optika.materials.snells_law(
194-
wavelength=wavelength,
195194
direction=direction_input,
196195
index_refraction=1,
197196
index_refraction_new=1,

optika/surfaces.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ def propagate_rays(
163163
wavelength_2 = wavelength_1 / r
164164

165165
b = optika.materials.snells_law(
166-
wavelength=wavelength_1,
167166
direction=a,
168167
index_refraction=n1,
169168
index_refraction_new=n2,

0 commit comments

Comments
 (0)