|
1 | | -from __future__ import annotations |
| 1 | +import math |
2 | 2 | import numpy as np |
| 3 | +import numba as nb |
3 | 4 | import astropy.units as u |
4 | 5 | import named_arrays as na |
5 | 6 |
|
@@ -38,20 +39,17 @@ def snells_law_scalar( |
38 | 39 |
|
39 | 40 |
|
40 | 41 | def snells_law( |
41 | | - wavelength: u.Quantity | na.AbstractScalar, |
42 | 42 | direction: na.AbstractCartesian3dVectorArray, |
43 | 43 | index_refraction: float | na.AbstractScalar, |
44 | 44 | index_refraction_new: float | na.AbstractScalar, |
45 | | - normal: None | na.AbstractCartesian3dVectorArray, |
| 45 | + normal: None | na.AbstractCartesian3dVectorArray = None, |
46 | 46 | is_mirror: bool | na.AbstractScalar = False, |
47 | 47 | ) -> na.Cartesian3dVectorArray: |
48 | 48 | r""" |
49 | 49 | A `vector form of Snell's law <https://en.wikipedia.org/wiki/Snell%27s_law#Vector_form>`_. |
50 | 50 |
|
51 | 51 | Parameters |
52 | 52 | ---------- |
53 | | - wavelength |
54 | | - The wavelength of the incoming light |
55 | 53 | direction |
56 | 54 | The propagation direction of the incoming light |
57 | 55 | index_refraction |
@@ -90,7 +88,6 @@ def snells_law( |
90 | 88 | # Define the keyword arguments that are common |
91 | 89 | # to both the reflected and transmitted ray |
92 | 90 | kwargs = dict( |
93 | | - wavelength=350 * u.nm, |
94 | 91 | direction=direction, |
95 | 92 | index_refraction=1, |
96 | 93 | normal=na.Cartesian3dVectorArray(0, 0, 1), |
@@ -271,22 +268,99 @@ def snells_law( |
271 | 268 | \pm \sqrt{\left( n_2 / n_1 \right)^2 + (\mathbf{k}_1 \cdot \hat{\mathbf{n}})^2 - k_1^2 } |
272 | 269 | \right) \hat{\mathbf{n}} \right]} |
273 | 270 | """ |
274 | | - a = direction |
275 | | - n1 = index_refraction # noqa: F841 |
276 | | - n2 = index_refraction_new # noqa: F841 |
277 | | - |
278 | 271 | if normal is None: |
279 | 272 | normal = na.Cartesian3dVectorArray(0, 0, -1) |
280 | 273 |
|
281 | | - a_x = a.x # noqa: F841 |
282 | | - a_y = a.y # noqa: F841 |
283 | | - a_z = a.z # noqa: F841 |
284 | | - u_x = normal.x # noqa: F841 |
285 | | - u_y = normal.y # noqa: F841 |
286 | | - u_z = normal.z # noqa: F841 |
287 | | - |
288 | | - return na.numexpr.evaluate( |
289 | | - "(n1 / n2) * (a + (-(a_x*u_x + a_y*u_y + a_z*u_z) + sign(-(a_x*u_x + a_y*u_y + a_z*u_z)) " |
290 | | - "* (2 * is_mirror - 1) * sqrt(1 / (n1 / n2)**2 + (a_x*u_x + a_y*u_y + a_z*u_z)**2" |
291 | | - "- (a_x*a_x + a_y*a_y + a_z*a_z))) * normal)" |
| 274 | + direction = direction << u.dimensionless_unscaled |
| 275 | + index_refraction = index_refraction << u.dimensionless_unscaled |
| 276 | + index_refraction_new = index_refraction_new << u.dimensionless_unscaled |
| 277 | + normal = normal << u.dimensionless_unscaled |
| 278 | + |
| 279 | + b_x, b_y, b_z = _snells_law_numba( |
| 280 | + direction.x.value, |
| 281 | + direction.y.value, |
| 282 | + direction.z.value, |
| 283 | + index_refraction.value, |
| 284 | + index_refraction_new.value, |
| 285 | + normal.x.value, |
| 286 | + normal.y.value, |
| 287 | + normal.z.value, |
| 288 | + is_mirror, |
292 | 289 | ) |
| 290 | + |
| 291 | + return na.Cartesian3dVectorArray(b_x, b_y, b_z) |
| 292 | + |
| 293 | + |
| 294 | +@nb.guvectorize( |
| 295 | + [ |
| 296 | + "void(float64,float64,float64,float64,float64,float64,float64,float64,bool,float64[:],float64[:],float64[:])" |
| 297 | + ], |
| 298 | + "(),(),(),(),(),(),(),(),()->(),(),()", |
| 299 | + target="parallel", |
| 300 | + nopython=True, |
| 301 | + cache=True, |
| 302 | +) |
| 303 | +def _snells_law_numba( |
| 304 | + direction_x: float, |
| 305 | + direction_y: float, |
| 306 | + direction_z: float, |
| 307 | + index_refraction: float, |
| 308 | + index_refraction_new: float, |
| 309 | + normal_x: float, |
| 310 | + normal_y: float, |
| 311 | + normal_z: float, |
| 312 | + is_mirror: bool, |
| 313 | + result_x: np.ndarray, |
| 314 | + result_y: np.ndarray, |
| 315 | + result_z: np.ndarray, |
| 316 | +): # pragma: nocover |
| 317 | + """ |
| 318 | + A :mod:`numba`-accelerated version of Snell's law. |
| 319 | +
|
| 320 | + Parameters |
| 321 | + ---------- |
| 322 | + direction_x |
| 323 | + The :math:`x` component of the propagation direction of the incident light. |
| 324 | + direction_y |
| 325 | + The :math:`y` component of the propagation direction of the incident light. |
| 326 | + direction_z |
| 327 | + The :math:`z` component of the propagation direction of the incident light. |
| 328 | + index_refraction |
| 329 | + The index of refraction of the current medium. |
| 330 | + index_refraction_new |
| 331 | + The index of refraction of the new medium. |
| 332 | + normal_x |
| 333 | + The :math:`x` component of the vector perpendicular to the interface. |
| 334 | + normal_y |
| 335 | + The :math:`y` component of the vector perpendicular to the interface. |
| 336 | + normal_z |
| 337 | + The :math:`z` component of the vector perpendicular to the interface. |
| 338 | + is_mirror |
| 339 | + Whether the incident light is reflected or not. |
| 340 | + """ |
| 341 | + a_x = direction_x |
| 342 | + a_y = direction_y |
| 343 | + a_z = direction_z |
| 344 | + |
| 345 | + n1 = index_refraction |
| 346 | + n2 = index_refraction_new |
| 347 | + |
| 348 | + u_x = normal_x |
| 349 | + u_y = normal_y |
| 350 | + u_z = normal_z |
| 351 | + |
| 352 | + a2 = a_x * a_x + a_y * a_y + a_z * a_z |
| 353 | + |
| 354 | + r = n1 / n2 |
| 355 | + r2 = r * r |
| 356 | + |
| 357 | + au = a_x * u_x + a_y * u_y + a_z * u_z |
| 358 | + au2 = au * au |
| 359 | + |
| 360 | + sgn = -math.copysign(1, au) |
| 361 | + |
| 362 | + d = -au + sgn * (2 * is_mirror - 1) * math.sqrt(1 / r2 + au2 - a2) |
| 363 | + |
| 364 | + result_x[:] = r * (a_x + d * u_x) |
| 365 | + result_y[:] = r * (a_y + d * u_y) |
| 366 | + result_z[:] = r * (a_z + d * u_z) |
0 commit comments