Skip to content

Commit a27a6a3

Browse files
authored
Optimize optika.materials.snells_law() using numexpr (#142)
1 parent eca44b7 commit a27a6a3

File tree

2 files changed

+31
-12
lines changed

2 files changed

+31
-12
lines changed

optika/materials/_snells_law.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -272,19 +272,21 @@ def snells_law(
272272
\right) \hat{\mathbf{n}} \right]}
273273
"""
274274
a = direction
275-
n1 = index_refraction
276-
n2 = index_refraction_new
275+
n1 = index_refraction # noqa: F841
276+
n2 = index_refraction_new # noqa: F841
277277

278278
if normal is None:
279279
normal = na.Cartesian3dVectorArray(0, 0, -1)
280280

281-
r = n1 / n2
282-
283-
c = -a @ normal
284-
285-
mirror = np.sign(c) * (2 * is_mirror - 1)
286-
287-
t = np.sqrt(np.square(1 / r) + np.square(c) - np.square(a.length))
288-
b = r * (a + (c + mirror * t) * normal)
289-
290-
return b
281+
a_x = a.x # noqa: F841
282+
a_y = a.y # noqa: F841
283+
a_z = a.z # noqa: F841
284+
u_x = normal.x # noqa: F841
285+
u_y = normal.y # noqa: F841
286+
u_z = normal.z # noqa: F841
287+
288+
return na.numexpr.evaluate(
289+
"(n1 / n2) * (a + (-(a_x*u_x + a_y*u_y + a_z*u_z) + sign(-(a_x*u_x + a_y*u_y + a_z*u_z)) "
290+
"* (2 * is_mirror - 1) * sqrt(1 / (n1 / n2)**2 + (a_x*u_x + a_y*u_y + a_z*u_z)**2"
291+
"- (a_x*a_x + a_y*a_y + a_z*a_z))) * normal)"
292+
)

optika/materials/_tests/test_snells_law.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_snells_law_scalar(
5757
argnames="direction",
5858
argvalues=[
5959
na.Cartesian3dVectorArray(0, 0, 1),
60+
optika.direction(na.Cartesian2dVectorArray(1 * u.deg, 2 * u.deg)),
6061
],
6162
)
6263
@pytest.mark.parametrize(
@@ -101,12 +102,28 @@ def test_snells_law(
101102
normal=normal,
102103
is_mirror=is_mirror,
103104
)
105+
106+
a = direction
107+
n1 = index_refraction # noqa: F841
108+
n2 = index_refraction_new # noqa: F841
109+
104110
if normal is None:
105111
normal = na.Cartesian3dVectorArray(0, 0, -1)
106112

113+
r = n1 / n2
114+
115+
c = -a @ normal
116+
117+
mirror = np.sign(c) * (2 * is_mirror - 1)
118+
119+
t = np.sqrt(np.square(1 / r) + np.square(c) - np.square(a.length))
120+
result_expected = r * (a + (c + mirror * t) * normal)
121+
107122
assert isinstance(result, na.AbstractCartesian3dVectorArray)
108123
assert np.allclose(result.length, 1)
109124
if is_mirror:
110125
assert not np.allclose(np.sign(direction @ normal), np.sign(result @ normal))
111126
else:
112127
assert np.allclose(np.sign(direction @ normal), np.sign(result @ normal))
128+
129+
assert np.allclose(result, result_expected)

0 commit comments

Comments
 (0)