|
| 1 | +import jax.numpy as jnp |
| 2 | +from jax import lax, vmap, jit |
| 3 | +from functools import partial |
| 4 | + |
| 5 | +from ._sources import current_density, calculate_charge_density |
| 6 | +from ._boundary_conditions import set_BC_positions, set_BC_particles |
| 7 | +from ._particles import fields_to_particles_grid, boris_step, boris_step_relativistic |
| 8 | +from ._constants import speed_of_light, epsilon_0, elementary_charge, mass_electron, mass_proton |
| 9 | +from ._fields import (field_update, E_from_Gauss_1D_Cartesian, E_from_Gauss_1D_FFT, |
| 10 | + E_from_Poisson_1D_FFT, field_update1, field_update2) |
| 11 | + |
| 12 | +try: import tomllib |
| 13 | +except ModuleNotFoundError: import pip._vendor.tomli as tomllib |
| 14 | + |
| 15 | +__all__ = ['Boris_step', 'CN_step'] |
| 16 | + |
| 17 | +#Boris step |
| 18 | +def Boris_step(carry, step_index, parameters, dx, dt, grid, box_size, |
| 19 | + particle_BC_left, particle_BC_right, |
| 20 | + field_BC_left, field_BC_right, |
| 21 | + field_solver): |
| 22 | + |
| 23 | + (E_field, B_field, positions_minus1_2, positions, |
| 24 | + positions_plus1_2, velocities, qs, ms, q_ms) = carry |
| 25 | + |
| 26 | + J = current_density(positions_minus1_2, positions, positions_plus1_2, velocities, |
| 27 | + qs, dx, dt, grid, grid[0] - dx / 2, particle_BC_left, particle_BC_right) |
| 28 | + E_field, B_field = field_update1(E_field, B_field, dx, dt/2, J, field_BC_left, field_BC_right) |
| 29 | + |
| 30 | + # Add external fields |
| 31 | + total_E = E_field + parameters["external_electric_field"] |
| 32 | + total_B = B_field + parameters["external_magnetic_field"] |
| 33 | + |
| 34 | + # Interpolate fields to particle positions |
| 35 | + def interpolate_fields(x_n): |
| 36 | + E = fields_to_particles_grid(x_n, total_E, dx, grid + dx/2, grid[0], field_BC_left, field_BC_right) |
| 37 | + B = fields_to_particles_grid(x_n, total_B, dx, grid, grid[0] - dx/2, field_BC_left, field_BC_right) |
| 38 | + return E, B |
| 39 | + |
| 40 | + E_field_at_x, B_field_at_x = vmap(interpolate_fields)(positions_plus1_2) |
| 41 | + |
| 42 | + # Particle update: Boris pusher |
| 43 | + positions_plus3_2, velocities_plus1 = lax.cond( |
| 44 | + parameters["relativistic"], |
| 45 | + lambda _: boris_step_relativistic(dt, positions_plus1_2, velocities, qs, ms, E_field_at_x, B_field_at_x), |
| 46 | + lambda _: boris_step(dt, positions_plus1_2, velocities, q_ms, E_field_at_x, B_field_at_x), |
| 47 | + operand=None |
| 48 | + ) |
| 49 | + |
| 50 | + # Apply boundary conditions |
| 51 | + positions_plus3_2, velocities_plus1, qs, ms, q_ms = set_BC_particles( |
| 52 | + positions_plus3_2, velocities_plus1, qs, ms, q_ms, dx, grid, |
| 53 | + *box_size, particle_BC_left, particle_BC_right) |
| 54 | + |
| 55 | + positions_plus1 = set_BC_positions(positions_plus3_2 - (dt / 2) * velocities_plus1, |
| 56 | + qs, dx, grid, *box_size, particle_BC_left, particle_BC_right) |
| 57 | + |
| 58 | + J = current_density(positions_plus1_2, positions_plus1, positions_plus3_2, velocities_plus1, |
| 59 | + qs, dx, dt, grid, grid[0] - dx / 2, particle_BC_left, particle_BC_right) |
| 60 | + E_field, B_field = field_update2(E_field, B_field, dx, dt/2, J, field_BC_left, field_BC_right) |
| 61 | + |
| 62 | + if field_solver != 0: |
| 63 | + charge_density = calculate_charge_density(positions, qs, dx, grid + dx / 2, particle_BC_left, particle_BC_right) |
| 64 | + switcher = { |
| 65 | + 1: E_from_Gauss_1D_FFT, |
| 66 | + 2: E_from_Gauss_1D_Cartesian, |
| 67 | + 3: E_from_Poisson_1D_FFT, |
| 68 | + } |
| 69 | + E_field = E_field.at[:,0].set(switcher[field_solver](charge_density, dx)) |
| 70 | + |
| 71 | + # Update positions and velocities |
| 72 | + positions_minus1_2, positions_plus1_2 = positions_plus1_2, positions_plus3_2 |
| 73 | + velocities = velocities_plus1 |
| 74 | + positions = positions_plus1 |
| 75 | + |
| 76 | + # Prepare state for the next step |
| 77 | + carry = (E_field, B_field, positions_minus1_2, positions, |
| 78 | + positions_plus1_2, velocities, qs, ms, q_ms) |
| 79 | + |
| 80 | + # Collect data for storage |
| 81 | + charge_density = calculate_charge_density(positions, qs, dx, grid, particle_BC_left, particle_BC_right) |
| 82 | + step_data = (positions, velocities, E_field, B_field, J, charge_density) |
| 83 | + |
| 84 | + return carry, step_data |
| 85 | + |
| 86 | +# Implicit Crank-Nicolson step |
| 87 | +@partial(jit, static_argnames=('num_substeps', 'particle_BC_left', 'particle_BC_right', 'field_BC_left', 'field_BC_right')) |
| 88 | +def CN_step(carry, step_index, parameters, dx, dt, grid, box_size, |
| 89 | + particle_BC_left, particle_BC_right, |
| 90 | + field_BC_left, field_BC_right, num_substeps): |
| 91 | + (E_field, B_field, positions, |
| 92 | + velocities, qs, ms, q_ms) = carry |
| 93 | + |
| 94 | + E_new=E_field |
| 95 | + B_new=B_field |
| 96 | + positions_new=positions+ (dt) * velocities |
| 97 | + velocities_new=velocities |
| 98 | + # initialize the array of half-substep positions for Picard iterations |
| 99 | + positions_sub1_2_all_init = jnp.repeat(positions[None, ...], num_substeps, axis=0) |
| 100 | + |
| 101 | + # Picard iteration of solution for next step |
| 102 | + substep_indices = jnp.arange(num_substeps) |
| 103 | + def picard_step(pic_carry, _): |
| 104 | + _, E_new, pos_fix, _, vel_fix, _, qs_prev, ms_prev, q_ms_prev, pos_stag_arr = pic_carry |
| 105 | + E_avg = 0.5 * (E_field + E_new) |
| 106 | + dtau = dt / num_substeps |
| 107 | + |
| 108 | + |
| 109 | + interp_E = partial(fields_to_particles_grid, dx=dx, grid=grid + dx/2, grid_start=grid[0], |
| 110 | + field_BC_left=field_BC_left, field_BC_right=field_BC_right) |
| 111 | + # substepping |
| 112 | + def substep_loop(sub_carry, step_idx): |
| 113 | + pos_sub, vel_sub, qs_sub, ms_sub, q_ms_sub, pos_stag_arr = sub_carry |
| 114 | + pos_stag_prev = pos_stag_arr[step_idx] |
| 115 | + |
| 116 | + E_mid = vmap(interp_E, in_axes=(0, None))(pos_stag_prev, E_avg) |
| 117 | + |
| 118 | + vel_new = vel_sub + (q_ms_sub * E_mid) * dtau |
| 119 | + vel_mid = 0.5 * (vel_sub + vel_new) |
| 120 | + pos_new = pos_sub + vel_mid * dtau |
| 121 | + |
| 122 | + # Apply boundary conditions |
| 123 | + pos_new, vel_mid, qs_new, ms_new, q_ms_new = set_BC_particles( |
| 124 | + pos_new, vel_mid, qs_sub, ms_sub, q_ms_sub, |
| 125 | + dx, grid, *box_size, particle_BC_left, particle_BC_right |
| 126 | + ) |
| 127 | + |
| 128 | + pos_stag_new = set_BC_positions( |
| 129 | + pos_new - 0.5*dtau*vel_mid, |
| 130 | + qs_new, dx, grid, |
| 131 | + *box_size, particle_BC_left, particle_BC_right |
| 132 | + ) |
| 133 | + # Update half substep positions |
| 134 | + pos_stag_arr = pos_stag_arr.at[step_idx].set(pos_stag_new) |
| 135 | + |
| 136 | + # half step current density |
| 137 | + J_sub = current_density( |
| 138 | + pos_sub, pos_stag_new, pos_new, vel_mid, qs_new, dx, dtau, grid, |
| 139 | + grid[0] - dx/2, particle_BC_left, particle_BC_right |
| 140 | + ) |
| 141 | + |
| 142 | + return (pos_new, vel_new, qs_new, ms_new, q_ms_new, pos_stag_arr), J_sub * dtau |
| 143 | + |
| 144 | + # initial substep carry |
| 145 | + sub_init = ( |
| 146 | + pos_fix, vel_fix, |
| 147 | + qs_prev, ms_prev, q_ms_prev, |
| 148 | + pos_stag_arr |
| 149 | + ) |
| 150 | + |
| 151 | + # run through all substeps |
| 152 | + (pos_final, vel_final, qs_final, ms_final, q_ms_final, pos_stag_arr), J_accs = lax.scan( |
| 153 | + substep_loop, |
| 154 | + sub_init, |
| 155 | + substep_indices |
| 156 | + ) |
| 157 | + |
| 158 | + # Sum over substep to get next step current with eletric field |
| 159 | + J_iter = jnp.sum(J_accs, axis=0) / dt |
| 160 | + mean_J = jnp.mean(J_iter, axis=0) |
| 161 | + E_next = E_field - (dt / epsilon_0) * (J_iter - mean_J) |
| 162 | + |
| 163 | + return ( |
| 164 | + (E_new, E_next, pos_fix, pos_final, vel_fix, vel_final, qs_final, ms_final, q_ms_final, pos_stag_arr), |
| 165 | + J_iter |
| 166 | + ) |
| 167 | + |
| 168 | + |
| 169 | + # Picard iteration |
| 170 | + picard_init = (E_new, positions, positions_new, velocities,velocities_new, qs, ms, q_ms, positions_sub1_2_all_init) |
| 171 | + tol = parameters["tolerance_Picard_iterations_implicit_CN"] |
| 172 | + max_iter = parameters["max_number_of_Picard_iterations_implicit_CN"] |
| 173 | + |
| 174 | + positions_sub1_2_all_init = jnp.tile(positions[None, ...], (num_substeps, 1, 1)) |
| 175 | + E_old = E_new |
| 176 | + delta_E0 = jnp.array(jnp.inf) |
| 177 | + iter_idx0 = jnp.array(0) |
| 178 | + |
| 179 | + picard_init = (E_old, E_new, positions, positions_new, velocities, velocities_new, qs, ms, q_ms, positions_sub1_2_all_init) |
| 180 | + state0 = (picard_init, jnp.zeros_like(E_new), delta_E0, iter_idx0) |
| 181 | + |
| 182 | + def cond_fn(state): |
| 183 | + _, _, delta_E, i = state |
| 184 | + return jnp.logical_and(delta_E > tol, i < max_iter) |
| 185 | + |
| 186 | + def body_fn(state): |
| 187 | + carry, _, _, i = state |
| 188 | + |
| 189 | + E_old = carry[0] |
| 190 | + |
| 191 | + new_carry, J_iter = picard_step(carry, None) |
| 192 | + E_next = new_carry[1] |
| 193 | + |
| 194 | + delta_E = jnp.abs(jnp.max(E_next - E_old)) / (jnp.max(jnp.abs(E_next))) |
| 195 | + return (new_carry, J_iter, delta_E, i + 1) |
| 196 | + |
| 197 | + final_carry, J, _, _ = lax.while_loop(cond_fn, body_fn, state0) |
| 198 | + (E_old, E_new, _, positions_new, _, velocities_new, _, _, _, _) = final_carry |
| 199 | + |
| 200 | + # Update carrys for next step |
| 201 | + E_field = E_new |
| 202 | + B_field = B_new |
| 203 | + positions_plus1= positions_new |
| 204 | + velocities_plus1 = velocities_new |
| 205 | + |
| 206 | + charge_density = calculate_charge_density(positions_new, qs, dx, grid, particle_BC_left, particle_BC_right) |
| 207 | + |
| 208 | + carry = (E_field, B_field, positions_plus1, velocities_plus1, qs, ms, q_ms) |
| 209 | + |
| 210 | + # Collect data |
| 211 | + step_data = (positions_plus1, velocities_plus1, E_field, B_field, J, charge_density) |
| 212 | + |
| 213 | + return carry, step_data |
0 commit comments