@@ -152,32 +152,35 @@ def apply_eps_to_bend_strains(xi: Array, _eps: float) -> Array:
152
152
"""
153
153
Add a small number to the bending strain to avoid singularities
154
154
"""
155
- xi_reshaped = xi .reshape ((- 1 , 3 ))
156
-
157
- xi_bend_sign = jnp .sign (xi_reshaped [:, 0 ])
158
- # set zero sign to 1 (i.e. positive)
159
- xi_bend_sign = jnp .where (xi_bend_sign == 0 , 1 , xi_bend_sign )
160
- # add eps to the bending strain (i.e. the first column)
161
- sigma_b_epsed = lax .select (
162
- jnp .abs (xi_reshaped [:, 0 ]) < _eps ,
163
- xi_bend_sign * _eps ,
164
- xi_reshaped [:, 0 ],
165
- )
166
- xi_epsed = jnp .stack (
167
- [
168
- sigma_b_epsed ,
169
- xi_reshaped [:, 1 ],
170
- xi_reshaped [:, 2 ],
171
- ],
172
- axis = 1 ,
173
- )
174
-
175
- # old implementation:
176
- # xi_epsed = xi_reshaped
177
- # xi_epsed = xi_epsed.at[:, 0].add(xi_bend_sign * _eps)
178
-
179
- # flatten the array
180
- xi_epsed = xi_epsed .flatten ()
155
+ if _eps == None :
156
+ return xi
157
+ else :
158
+ xi_reshaped = xi .reshape ((- 1 , 3 ))
159
+
160
+ xi_bend_sign = jnp .sign (xi_reshaped [:, 0 ])
161
+ # set zero sign to 1 (i.e. positive)
162
+ xi_bend_sign = jnp .where (xi_bend_sign == 0 , 1 , xi_bend_sign )
163
+ # add eps to the bending strain (i.e. the first column)
164
+ sigma_b_epsed = lax .select (
165
+ jnp .abs (xi_reshaped [:, 0 ]) < _eps ,
166
+ xi_bend_sign * _eps ,
167
+ xi_reshaped [:, 0 ],
168
+ )
169
+ xi_epsed = jnp .stack (
170
+ [
171
+ sigma_b_epsed ,
172
+ xi_reshaped [:, 1 ],
173
+ xi_reshaped [:, 2 ],
174
+ ],
175
+ axis = 1 ,
176
+ )
177
+
178
+ # old implementation:
179
+ # xi_epsed = xi_reshaped
180
+ # xi_epsed = xi_epsed.at[:, 0].add(xi_bend_sign * _eps)
181
+
182
+ # flatten the array
183
+ xi_epsed = xi_epsed .flatten ()
181
184
182
185
return xi_epsed
183
186
0 commit comments