Skip to content

Commit 7c8e7fb

Browse files
authored
Merge pull request #10 from uwplasma/lma/enegy_conservation_to_main
Crank Nicolson implicit method from new version
2 parents d055e57 + c2fbd8d commit 7c8e7fb

File tree

10 files changed

+283
-92
lines changed

10 files changed

+283
-92
lines changed

example_input.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ wavenumber_electrons = 8
55
grid_points_per_Debye_length = 2
66
vth_electrons_over_c_x = 0.05
77
ion_temperature_over_electron_temperature = 0.01
8-
timestep_over_spatialstep_times_c = 1.0
8+
timestep_over_spatialstep_times_c = 5.0
99
electron_drift_speed_x = 100000000.0
1010
velocity_plus_minus_electrons_x = true
1111
print_info = true
@@ -15,9 +15,13 @@ electron_charge_over_elementary_charge = -1
1515
ion_charge_over_elementary_charge = 1
1616
ion_mass_over_proton_mass = 1
1717
relativistic = false
18+
tolerance_Picard_iterations_implicit_CN = 1e-5
1819

1920
[solver_parameters]
2021
field_solver = 0
22+
time_evolution_algorithm = 1
2123
number_grid_points = 100
2224
number_pseudoelectrons = 3000
23-
total_steps = 1000
25+
total_steps = 200
26+
max_number_of_Picard_iterations_implicit_CN = 30
27+
number_of_particle_substeps_implicit_CN = 2

example_script.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
print(f"Run #{i+1}: Wall clock time: {time.time()-start}s")
1717

1818
# Plot the results
19-
plot(output)
19+
plot(output)

examples/Landau_damping.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,26 @@
66
from jax import block_until_ready
77

88
input_parameters = {
9-
"length" : 1e-1, # dimensions of the simulation box in (x, y, z)
10-
"amplitude_perturbation_x" : 1e-1, # amplitude of sinusoidal perturbation in x
11-
"wavenumber_electrons_x": 1, # Wavenumber of sinusoidal electron density perturbation in x (factor of 2pi/length)
12-
"grid_points_per_Debye_length" : 0.3, # dx over Debye length
13-
"velocity_plus_minus_electrons_x": False, # create two groups of electrons moving in opposite directions
14-
"ion_temperature_over_electron_temperature_x": 1e-6, # Temperature of ions over temperature of electrons
9+
"length" : 1e-2, # dimensions of the simulation box in (x, y, z)
10+
"amplitude_perturbation_x" : 4e-3, # amplitude of sinusoidal perturbation in x
11+
"wavenumber_electrons_x": 8, # Wavenumber of sinusoidal electron density perturbation in x (factor of 2pi/length)
12+
"velocity_plus_minus_electrons_x": False, # create two groups of electrons moving in opposite directions
13+
"grid_points_per_Debye_length" : 10, # dx over Debye length
14+
"vth_electrons_over_c_x" : 0.05, # thermal velocity of electrons over speed of light
15+
"ion_temperature_over_electron_temperature_x": 1e-9, # Temperature of ions over temperature of electrons
16+
"timestep_over_spatialstep_times_c": 3, # dt * speed_of_light / dx
1517
"print_info" : True, # print information about the simulation
16-
"ion_mass_over_proton_mass": 1e6, # Ion mass in units of the proton mass
17-
"vth_electrons_over_c_x": 1e-1, # Thermal velocity of electrons over speed of light
18+
"tolerance_Picard_iterations_implicit_CN": 1e-5, # Tolerance for Picard iterations
1819
}
1920

2021
solver_parameters = {
2122
"field_solver" : 0, # Algorithm to solve E and B fields - 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT,
22-
"number_grid_points" : 201, # Number of grid points
23-
"number_pseudoelectrons" : 5000, # Number of pseudoelectrons
24-
"total_steps" : 1500, # Total number of time steps
23+
"number_grid_points" : 81, # Number of grid points
24+
"number_pseudoelectrons" : 3000, # Number of pseudoelectrons
25+
"total_steps" : 200, # Total number of time steps
26+
"time_evolution_algorithm": 1, # Algorithm to evolve particles in time - 0: Boris, 1: Implicit_Crank Nicholson
27+
"max_number_of_Picard_iterations_implicit_CN": 30, # Maximum number of iterations for Picard iteration converging
28+
"number_of_particle_substeps_implicit_CN": 2, # The number of substep for one time eletric field update
2529
}
2630

2731
output = block_until_ready(simulation(input_parameters, **solver_parameters))

examples/Weibel_instability.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
"field_solver" : 0, # Algorithm to solve E and B fields - 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT,
2929
"number_grid_points" : 201, # Number of grid points
3030
"number_pseudoelectrons" : 3000, # Number of pseudoelectrons
31-
"total_steps" : 8000, # Total number of time steps
31+
"total_steps" : 6000, # Total number of time steps
32+
"time_evolution_algorithm": 0, # Algorithm to evolve particles in time - 0: Boris, 1: Implicit_Crank Nicholson (electrostatic)
3233
}
3334

3435
output = block_until_ready(simulation(input_parameters, **solver_parameters))

examples/input.toml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,20 @@ wavenumber_ions_x = 0
66
grid_points_per_Debye_length = 2.0
77
vth_electrons_over_c_x = 0.05
88
ion_temperature_over_electron_temperature_x = 0.01
9-
timestep_over_spatialstep_times_c = 1.0
9+
timestep_over_spatialstep_times_c = 4.0
1010
electron_drift_speed_x = 150000000.0
1111
velocity_plus_minus_electrons_x = true
1212
print_info = true
1313
external_electric_field_amplitude = 0
1414
external_electric_field_wavenumber = 0
15-
relativistic = false
15+
relativistic = true
16+
tolerance_Picard_iterations_implicit_CN = 1e-5
1617

1718
[solver_parameters]
19+
time_evolution_algorithm = 1
20+
max_number_of_Picard_iterations_implicit_CN = 30
21+
number_of_particle_substeps_implicit_CN = 2
1822
field_solver = 0
19-
number_grid_points = 81
23+
number_grid_points = 61
2024
number_pseudoelectrons = 3000
21-
total_steps = 1000
25+
total_steps = 700

examples/optimize_two_stream_saturation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def objective_function(Ti):
5959
############### -------- #################
6060
print(f'Perform a simple optimization with {max_iterations_optimization} iterations')
6161
## Using Least Squares
62-
res = least_squares(objective_function, x0=x0_optimization, jac=learning_rate, verbose=2, max_nfev=max_iterations_optimization)
62+
res = least_squares(objective_function, x0=x0_optimization, verbose=2, max_nfev=max_iterations_optimization)
6363
optimized_Ti = res.x[0]
6464
## Using OPTAX
6565
# import optax

jaxincell/_algorithms.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import jax.numpy as jnp
2+
from jax import lax, vmap, jit
3+
from functools import partial
4+
5+
from ._sources import current_density, calculate_charge_density
6+
from ._boundary_conditions import set_BC_positions, set_BC_particles
7+
from ._particles import fields_to_particles_grid, boris_step, boris_step_relativistic
8+
from ._constants import speed_of_light, epsilon_0, elementary_charge, mass_electron, mass_proton
9+
from ._fields import (field_update, E_from_Gauss_1D_Cartesian, E_from_Gauss_1D_FFT,
10+
E_from_Poisson_1D_FFT, field_update1, field_update2)
11+
12+
try: import tomllib
13+
except ModuleNotFoundError: import pip._vendor.tomli as tomllib
14+
15+
__all__ = ['Boris_step', 'CN_step']
16+
17+
#Boris step
18+
def Boris_step(carry, step_index, parameters, dx, dt, grid, box_size,
19+
particle_BC_left, particle_BC_right,
20+
field_BC_left, field_BC_right,
21+
field_solver):
22+
23+
(E_field, B_field, positions_minus1_2, positions,
24+
positions_plus1_2, velocities, qs, ms, q_ms) = carry
25+
26+
J = current_density(positions_minus1_2, positions, positions_plus1_2, velocities,
27+
qs, dx, dt, grid, grid[0] - dx / 2, particle_BC_left, particle_BC_right)
28+
E_field, B_field = field_update1(E_field, B_field, dx, dt/2, J, field_BC_left, field_BC_right)
29+
30+
# Add external fields
31+
total_E = E_field + parameters["external_electric_field"]
32+
total_B = B_field + parameters["external_magnetic_field"]
33+
34+
# Interpolate fields to particle positions
35+
def interpolate_fields(x_n):
36+
E = fields_to_particles_grid(x_n, total_E, dx, grid + dx/2, grid[0], field_BC_left, field_BC_right)
37+
B = fields_to_particles_grid(x_n, total_B, dx, grid, grid[0] - dx/2, field_BC_left, field_BC_right)
38+
return E, B
39+
40+
E_field_at_x, B_field_at_x = vmap(interpolate_fields)(positions_plus1_2)
41+
42+
# Particle update: Boris pusher
43+
positions_plus3_2, velocities_plus1 = lax.cond(
44+
parameters["relativistic"],
45+
lambda _: boris_step_relativistic(dt, positions_plus1_2, velocities, qs, ms, E_field_at_x, B_field_at_x),
46+
lambda _: boris_step(dt, positions_plus1_2, velocities, q_ms, E_field_at_x, B_field_at_x),
47+
operand=None
48+
)
49+
50+
# Apply boundary conditions
51+
positions_plus3_2, velocities_plus1, qs, ms, q_ms = set_BC_particles(
52+
positions_plus3_2, velocities_plus1, qs, ms, q_ms, dx, grid,
53+
*box_size, particle_BC_left, particle_BC_right)
54+
55+
positions_plus1 = set_BC_positions(positions_plus3_2 - (dt / 2) * velocities_plus1,
56+
qs, dx, grid, *box_size, particle_BC_left, particle_BC_right)
57+
58+
J = current_density(positions_plus1_2, positions_plus1, positions_plus3_2, velocities_plus1,
59+
qs, dx, dt, grid, grid[0] - dx / 2, particle_BC_left, particle_BC_right)
60+
E_field, B_field = field_update2(E_field, B_field, dx, dt/2, J, field_BC_left, field_BC_right)
61+
62+
if field_solver != 0:
63+
charge_density = calculate_charge_density(positions, qs, dx, grid + dx / 2, particle_BC_left, particle_BC_right)
64+
switcher = {
65+
1: E_from_Gauss_1D_FFT,
66+
2: E_from_Gauss_1D_Cartesian,
67+
3: E_from_Poisson_1D_FFT,
68+
}
69+
E_field = E_field.at[:,0].set(switcher[field_solver](charge_density, dx))
70+
71+
# Update positions and velocities
72+
positions_minus1_2, positions_plus1_2 = positions_plus1_2, positions_plus3_2
73+
velocities = velocities_plus1
74+
positions = positions_plus1
75+
76+
# Prepare state for the next step
77+
carry = (E_field, B_field, positions_minus1_2, positions,
78+
positions_plus1_2, velocities, qs, ms, q_ms)
79+
80+
# Collect data for storage
81+
charge_density = calculate_charge_density(positions, qs, dx, grid, particle_BC_left, particle_BC_right)
82+
step_data = (positions, velocities, E_field, B_field, J, charge_density)
83+
84+
return carry, step_data
85+
86+
# Implicit Crank-Nicolson step
87+
@partial(jit, static_argnames=('num_substeps', 'particle_BC_left', 'particle_BC_right', 'field_BC_left', 'field_BC_right'))
88+
def CN_step(carry, step_index, parameters, dx, dt, grid, box_size,
89+
particle_BC_left, particle_BC_right,
90+
field_BC_left, field_BC_right, num_substeps):
91+
(E_field, B_field, positions,
92+
velocities, qs, ms, q_ms) = carry
93+
94+
E_new=E_field
95+
B_new=B_field
96+
positions_new=positions+ (dt) * velocities
97+
velocities_new=velocities
98+
# initialize the array of half-substep positions for Picard iterations
99+
positions_sub1_2_all_init = jnp.repeat(positions[None, ...], num_substeps, axis=0)
100+
101+
# Picard iteration of solution for next step
102+
substep_indices = jnp.arange(num_substeps)
103+
def picard_step(pic_carry, _):
104+
_, E_new, pos_fix, _, vel_fix, _, qs_prev, ms_prev, q_ms_prev, pos_stag_arr = pic_carry
105+
E_avg = 0.5 * (E_field + E_new)
106+
dtau = dt / num_substeps
107+
108+
109+
interp_E = partial(fields_to_particles_grid, dx=dx, grid=grid + dx/2, grid_start=grid[0],
110+
field_BC_left=field_BC_left, field_BC_right=field_BC_right)
111+
# substepping
112+
def substep_loop(sub_carry, step_idx):
113+
pos_sub, vel_sub, qs_sub, ms_sub, q_ms_sub, pos_stag_arr = sub_carry
114+
pos_stag_prev = pos_stag_arr[step_idx]
115+
116+
E_mid = vmap(interp_E, in_axes=(0, None))(pos_stag_prev, E_avg)
117+
118+
vel_new = vel_sub + (q_ms_sub * E_mid) * dtau
119+
vel_mid = 0.5 * (vel_sub + vel_new)
120+
pos_new = pos_sub + vel_mid * dtau
121+
122+
# Apply boundary conditions
123+
pos_new, vel_mid, qs_new, ms_new, q_ms_new = set_BC_particles(
124+
pos_new, vel_mid, qs_sub, ms_sub, q_ms_sub,
125+
dx, grid, *box_size, particle_BC_left, particle_BC_right
126+
)
127+
128+
pos_stag_new = set_BC_positions(
129+
pos_new - 0.5*dtau*vel_mid,
130+
qs_new, dx, grid,
131+
*box_size, particle_BC_left, particle_BC_right
132+
)
133+
# Update half substep positions
134+
pos_stag_arr = pos_stag_arr.at[step_idx].set(pos_stag_new)
135+
136+
# half step current density
137+
J_sub = current_density(
138+
pos_sub, pos_stag_new, pos_new, vel_mid, qs_new, dx, dtau, grid,
139+
grid[0] - dx/2, particle_BC_left, particle_BC_right
140+
)
141+
142+
return (pos_new, vel_new, qs_new, ms_new, q_ms_new, pos_stag_arr), J_sub * dtau
143+
144+
# initial substep carry
145+
sub_init = (
146+
pos_fix, vel_fix,
147+
qs_prev, ms_prev, q_ms_prev,
148+
pos_stag_arr
149+
)
150+
151+
# run through all substeps
152+
(pos_final, vel_final, qs_final, ms_final, q_ms_final, pos_stag_arr), J_accs = lax.scan(
153+
substep_loop,
154+
sub_init,
155+
substep_indices
156+
)
157+
158+
# Sum over substep to get next step current with eletric field
159+
J_iter = jnp.sum(J_accs, axis=0) / dt
160+
mean_J = jnp.mean(J_iter, axis=0)
161+
E_next = E_field - (dt / epsilon_0) * (J_iter - mean_J)
162+
163+
return (
164+
(E_new, E_next, pos_fix, pos_final, vel_fix, vel_final, qs_final, ms_final, q_ms_final, pos_stag_arr),
165+
J_iter
166+
)
167+
168+
169+
# Picard iteration
170+
picard_init = (E_new, positions, positions_new, velocities,velocities_new, qs, ms, q_ms, positions_sub1_2_all_init)
171+
tol = parameters["tolerance_Picard_iterations_implicit_CN"]
172+
max_iter = parameters["max_number_of_Picard_iterations_implicit_CN"]
173+
174+
positions_sub1_2_all_init = jnp.tile(positions[None, ...], (num_substeps, 1, 1))
175+
E_old = E_new
176+
delta_E0 = jnp.array(jnp.inf)
177+
iter_idx0 = jnp.array(0)
178+
179+
picard_init = (E_old, E_new, positions, positions_new, velocities, velocities_new, qs, ms, q_ms, positions_sub1_2_all_init)
180+
state0 = (picard_init, jnp.zeros_like(E_new), delta_E0, iter_idx0)
181+
182+
def cond_fn(state):
183+
_, _, delta_E, i = state
184+
return jnp.logical_and(delta_E > tol, i < max_iter)
185+
186+
def body_fn(state):
187+
carry, _, _, i = state
188+
189+
E_old = carry[0]
190+
191+
new_carry, J_iter = picard_step(carry, None)
192+
E_next = new_carry[1]
193+
194+
delta_E = jnp.abs(jnp.max(E_next - E_old)) / (jnp.max(jnp.abs(E_next)))
195+
return (new_carry, J_iter, delta_E, i + 1)
196+
197+
final_carry, J, _, _ = lax.while_loop(cond_fn, body_fn, state0)
198+
(E_old, E_new, _, positions_new, _, velocities_new, _, _, _, _) = final_carry
199+
200+
# Update carrys for next step
201+
E_field = E_new
202+
B_field = B_new
203+
positions_plus1= positions_new
204+
velocities_plus1 = velocities_new
205+
206+
charge_density = calculate_charge_density(positions_new, qs, dx, grid, particle_BC_left, particle_BC_right)
207+
208+
carry = (E_field, B_field, positions_plus1, velocities_plus1, qs, ms, q_ms)
209+
210+
# Collect data
211+
step_data = (positions_plus1, velocities_plus1, E_field, B_field, J, charge_density)
212+
213+
return carry, step_data

jaxincell/_diagnostics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def diagnostics(output):
2525
dominant_frequency = jnp.abs(freqs[peak_index])
2626

2727
def integrate(y, dx): return 0.5 * (jnp.asarray(dx) * (y[..., 1:] + y[..., :-1])).sum(-1)
28+
# def integrate(y, dx): return jnp.sum(y, axis=-1) * dx
2829

2930
abs_E_squared = jnp.sum(output['electric_field']**2, axis=-1)
3031
abs_externalE_squared = jnp.sum(output['external_electric_field']**2, axis=-1)

jaxincell/_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label):
160160
if jnp.max(output["magnetic_field_energy"]) > 1e-10:
161161
energy_ax.plot(time, output["magnetic_field_energy"], label="Magnetic field energy")
162162
energy_ax.plot(time[2:], jnp.abs(jnp.mean(output["charge_density"][2:], axis=-1))*1e15, label=r"Mean $\rho \times 10^{15}$")
163-
energy_ax.plot(time[2:], jnp.abs(output["total_energy"][2:] - output["total_energy"][2]) / output["total_energy"][2], label="Relative energy error")
163+
energy_ax.plot(time[1:], jnp.abs(output["total_energy"][1:] - output["total_energy"][0]) / output["total_energy"][0], label="Relative energy error")
164164
energy_ax.set(title="Energy", xlabel=r"Time ($\omega_{pe}^{-1}$)",
165165
ylabel="Energy (J)", yscale="log", ylim=[1e-7, None])
166166
energy_ax.legend(fontsize=7)

0 commit comments

Comments
 (0)