Skip to content

Commit 7c17c6a

Browse files
committed
Changing Coriolis for loop to vmap,
get rid of jnp.array when possible, benchmark and tests on eps.
1 parent 6dd50a4 commit 7c17c6a

File tree

6 files changed

+557
-224
lines changed

6 files changed

+557
-224
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,5 @@ dmypy.json
135135

136136
# DS_STORE
137137
.DS_Store
138+
*.mp4
139+
*.gif

src/jsrm/systems/planar_pcs.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -152,32 +152,35 @@ def apply_eps_to_bend_strains(xi: Array, _eps: float) -> Array:
152152
"""
153153
Add a small number to the bending strain to avoid singularities
154154
"""
155-
xi_reshaped = xi.reshape((-1, 3))
156-
157-
xi_bend_sign = jnp.sign(xi_reshaped[:, 0])
158-
# set zero sign to 1 (i.e. positive)
159-
xi_bend_sign = jnp.where(xi_bend_sign == 0, 1, xi_bend_sign)
160-
# add eps to the bending strain (i.e. the first column)
161-
sigma_b_epsed = lax.select(
162-
jnp.abs(xi_reshaped[:, 0]) < _eps,
163-
xi_bend_sign * _eps,
164-
xi_reshaped[:, 0],
165-
)
166-
xi_epsed = jnp.stack(
167-
[
168-
sigma_b_epsed,
169-
xi_reshaped[:, 1],
170-
xi_reshaped[:, 2],
171-
],
172-
axis=1,
173-
)
174-
175-
# old implementation:
176-
# xi_epsed = xi_reshaped
177-
# xi_epsed = xi_epsed.at[:, 0].add(xi_bend_sign * _eps)
178-
179-
# flatten the array
180-
xi_epsed = xi_epsed.flatten()
155+
if _eps == None:
156+
return xi
157+
else:
158+
xi_reshaped = xi.reshape((-1, 3))
159+
160+
xi_bend_sign = jnp.sign(xi_reshaped[:, 0])
161+
# set zero sign to 1 (i.e. positive)
162+
xi_bend_sign = jnp.where(xi_bend_sign == 0, 1, xi_bend_sign)
163+
# add eps to the bending strain (i.e. the first column)
164+
sigma_b_epsed = lax.select(
165+
jnp.abs(xi_reshaped[:, 0]) < _eps,
166+
xi_bend_sign * _eps,
167+
xi_reshaped[:, 0],
168+
)
169+
xi_epsed = jnp.stack(
170+
[
171+
sigma_b_epsed,
172+
xi_reshaped[:, 1],
173+
xi_reshaped[:, 2],
174+
],
175+
axis=1,
176+
)
177+
178+
# old implementation:
179+
# xi_epsed = xi_reshaped
180+
# xi_epsed = xi_epsed.at[:, 0].add(xi_bend_sign * _eps)
181+
182+
# flatten the array
183+
xi_epsed = xi_epsed.flatten()
181184

182185
return xi_epsed
183186

0 commit comments

Comments
 (0)