Skip to content

Commit c75ede2

Browse files
committed
Add functionality to change spacecharge PIC mesh at runtime
1 parent b02664d commit c75ede2

File tree

3 files changed

+302
-3
lines changed

3 files changed

+302
-3
lines changed

xfields/beam_elements/spacecharge.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77

88
from xfields import BiGaussianFieldMap, mean_and_std
9-
from xfields import TriLinearInterpolatedFieldMap
9+
from ..fieldmaps.interpolated import TriLinearInterpolatedFieldMap
1010
from ..longitudinal_profiles import LongitudinalProfileQGaussian
1111
from ..fieldmaps import BiGaussianFieldMap
1212
from ..general import _pkg_root
@@ -182,6 +182,48 @@ def track(self, particles):
182182
# call C tracking kernel
183183
super().track(particles)
184184

185+
def set_xy_mesh(self, x_range, y_range, *, zero_fields=True):
186+
"""
187+
Retile the underlying fieldmap in x/y while
188+
keeping the grid counts unchanged.
189+
190+
Args:
191+
x_range (tuple[float, float]): New ``(xmin, xmax)`` in
192+
meters.
193+
y_range (tuple[float, float]): New ``(ymin, ymax)`` in
194+
meters.
195+
zero_fields (bool): If ``True``, zero stored
196+
``rho``, ``phi`` and their derivatives after retiling.
197+
"""
198+
(xmin, xmax) = map(float, x_range)
199+
(ymin, ymax) = map(float, y_range)
200+
self.fieldmap.retile_xy(
201+
xmin, xmax, ymin, ymax, zero_fields=zero_fields
202+
)
203+
204+
def set_xyz_mesh(self, *, x_range, y_range, z_range,
205+
zero_fields=True):
206+
"""
207+
Retile the underlying fieldmap in x/y/z while
208+
keeping the grid counts unchanged.
209+
210+
Args:
211+
x_range (tuple[float, float]): New ``(xmin, xmax)`` in
212+
meters.
213+
y_range (tuple[float, float]): New ``(ymin, ymax)`` in
214+
meters.
215+
z_range (tuple[float, float]): New ``(zmin, zmax)`` in
216+
meters.
217+
zero_fields (bool): If ``True``, zero stored
218+
``rho``, ``phi``, and their derivatives after retiling.
219+
"""
220+
(xmin, xmax) = map(float, x_range)
221+
(ymin, ymax) = map(float, y_range)
222+
(zmin, zmax) = map(float, z_range)
223+
self.fieldmap.retile_xyz(
224+
xmin, xmax, ymin, ymax, zmin, zmax, zero_fields=zero_fields
225+
)
226+
185227
class SpaceChargeBiGaussian(xt.BeamElement):
186228

187229
_xofields = {

xfields/fieldmaps/interpolated.py

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,116 @@ def generate_solver(self, solver, fftplan):
543543

544544
return solver
545545

546+
def retile_xy(self, xmin, xmax, ymin, ymax, *, zero_fields=True):
547+
"""
548+
Retile the x/y mesh in-place while keeping the
549+
number of cells unchanged.
550+
551+
Args:
552+
xmin (float): New minimum x coordinate in meters.
553+
xmax (float): New maximum x coordinate in meters.
554+
ymin (float): New minimum y coordinate in meters.
555+
ymax (float): New maximum y coordinate in meters.
556+
zero_fields (bool): If ``True``, zero stored
557+
``rho``, ``phi`` and their derivatives after retiling.
558+
"""
559+
nx, ny = len(self._x_grid), len(self._y_grid)
560+
561+
self._x_grid = np.linspace(float(xmin), float(xmax), nx,
562+
dtype=np.float64)
563+
self._y_grid = np.linspace(float(ymin), float(ymax), ny,
564+
dtype=np.float64)
565+
566+
if zero_fields:
567+
self._rho[...] = 0.0
568+
self._phi[...] = 0.0
569+
self._dphi_dx[...] = 0.0
570+
self._dphi_dy[...] = 0.0
571+
self._dphi_dz[...] = 0.0
572+
573+
self._dx = (xmax - xmin) / (nx - 1)
574+
self._dy = (ymax - ymin) / (ny - 1)
575+
self._x_min = float(xmin)
576+
self._y_min = float(ymin)
577+
578+
# Update derived volume quantities (dz unchanged)
579+
self._cell_volume = self._dx * self._dy * self._dz
580+
self._inv_cell_volume = 1.0 / self._cell_volume
581+
582+
# Refresh the solver geometry
583+
scale_dx, scale_dy, scale_dz = self.scale_coordinates_in_solver
584+
self.solver.refresh_geometry(
585+
self._x_grid * scale_dx,
586+
self._y_grid * scale_dy,
587+
self._z_grid * scale_dz)
588+
589+
def retile_xyz(
590+
self,
591+
xmin,
592+
xmax,
593+
ymin,
594+
ymax,
595+
zmin,
596+
zmax,
597+
*,
598+
zero_fields=True
599+
):
600+
"""
601+
Retile the full x/y/z mesh in-place while
602+
keeping the number of cells unchanged.
603+
604+
Args:
605+
xmin (float): New minimum x coordinate in meters.
606+
xmax (float): New maximum x coordinate in meters.
607+
ymin (float): New minimum y coordinate in meters.
608+
ymax (float): New maximum y coordinate in meters.
609+
zmin (float): New minimum z coordinate in meters.
610+
zmax (float): New maximum z coordinate in meters.
611+
zero_fields (bool): If ``True``, zero stored
612+
``rho``, ``phi`` and their derivatives after retiling.
613+
"""
614+
nx, ny, nz = self._x_grid.size, self._y_grid.size, self._z_grid.size
615+
616+
self._x_grid = np.linspace(
617+
float(xmin),
618+
float(xmax),
619+
nx,
620+
dtype=np.float64)
621+
self._y_grid = np.linspace(
622+
float(ymin),
623+
float(ymax),
624+
ny,
625+
dtype=np.float64)
626+
self._z_grid = np.linspace(
627+
float(zmin),
628+
float(zmax),
629+
nz,
630+
dtype=np.float64)
631+
632+
if zero_fields:
633+
self._rho[...] = 0.0
634+
self._phi[...] = 0.0
635+
self._dphi_dx[...] = 0.0
636+
self._dphi_dy[...] = 0.0
637+
self._dphi_dz[...] = 0.0
638+
639+
self._dx = (xmax - xmin) / (nx - 1)
640+
self._dy = (ymax - ymin) / (ny - 1)
641+
self._dz = (zmax - zmin) / (nz - 1)
642+
self._x_min = float(xmin)
643+
self._y_min = float(ymin)
644+
self._z_min = float(zmin)
645+
646+
self._cell_volume = self._dx * self._dy * self._dz
647+
self._inv_cell_volume = 1.0 / self._cell_volume
648+
649+
# Refresh the solver geometry
650+
scale_dx, scale_dy, scale_dz = self.scale_coordinates_in_solver
651+
self.solver.refresh_geometry(
652+
self._x_grid * scale_dx,
653+
self._y_grid * scale_dy,
654+
self._z_grid * scale_dz)
655+
546656
@property
547657
def x_grid(self):
548658
"""
@@ -564,6 +674,27 @@ def z_grid(self):
564674
"""
565675
return self._z_grid
566676

677+
@property
678+
def x_range(self):
679+
"""
680+
Horizontal range.
681+
"""
682+
return (float(self._x_grid[0]), float(self._x_grid[-1]))
683+
684+
@property
685+
def y_range(self):
686+
"""
687+
Vertical range.
688+
"""
689+
return (float(self._y_grid[0]), float(self._y_grid[-1]))
690+
691+
@property
692+
def z_range(self):
693+
"""
694+
Longitudinal range.
695+
"""
696+
return (float(self._z_grid[0]), float(self._z_grid[-1]))
697+
567698
@property
568699
def nx(self):
569700
"""
@@ -665,5 +796,3 @@ def _configure_grid(vname, v_grid, dv, v_range, nv):
665796
v_grid = np.linspace(v_range[0], v_range[1], nv)
666797

667798
return v_grid
668-
669-

xfields/solvers/fftsolvers.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,83 @@ def __init__(self, dx, dy, dz, nx, ny, nz, context=None, fftplan=None):
101101
self._gint_rep_transf_dev = gint_rep_dev
102102
self.fftplan = fftplan
103103

104+
def refresh_geometry(self, x_grid, y_grid, z_grid):
105+
"""
106+
Rebuild the 3D Green's function used by the solver.
107+
108+
The number of cells in each direction must match the original
109+
solver.
110+
111+
Args:
112+
x_grid (array-like): Monotonic array of x cell centers in
113+
meters.
114+
y_grid (array-like): Monotonic array of y cell centers in
115+
meters.
116+
z_grid (array-like): Monotonic array of z cell centers in
117+
meters.
118+
"""
119+
nx, ny, nz = len(x_grid), len(y_grid), len(z_grid)
120+
assert (nx, ny, nz) == (self.nx, self.ny, self.nz), (
121+
"refresh_geometry requires unchanged (nx, ny, nz)"
122+
)
123+
124+
dx = float(x_grid[1] - x_grid[0])
125+
dy = float(y_grid[1] - y_grid[0])
126+
dz = float(z_grid[1] - z_grid[0])
127+
128+
xg_F = np.arange(0, nx + 2) * dx - dx / 2.0
129+
yg_F = np.arange(0, ny + 2) * dy - dy / 2.0
130+
zg_F = np.arange(0, nz + 2) * dz - dz / 2.0
131+
XX_F, YY_F, ZZ_F = np.meshgrid(xg_F, yg_F, zg_F, indexing='ij')
132+
133+
F_temp = primitive_func_3d(XX_F, YY_F, ZZ_F)
134+
135+
gint_rep = np.zeros(
136+
(2 * nx, 2 * ny, 2 * nz),
137+
dtype=np.complex128,
138+
order='F')
139+
140+
gint_rep[:nx + 1, :ny + 1, :nz + 1] = (
141+
F_temp[1:, 1:, 1:]
142+
- F_temp[:-1, 1:, 1:]
143+
- F_temp[1:, :-1, 1:]
144+
+ F_temp[:-1, :-1, 1:]
145+
- F_temp[1:, 1:, :-1]
146+
+ F_temp[:-1, 1:, :-1]
147+
+ F_temp[1:, :-1, :-1]
148+
- F_temp[:-1, :-1, :-1]
149+
)
150+
151+
gint_rep[nx + 1:, :ny + 1, :nz + 1] = gint_rep[nx - 1:0:-1,
152+
:ny + 1,
153+
:nz + 1]
154+
gint_rep[:nx + 1, ny + 1:, :nz + 1] = gint_rep[:nx + 1,
155+
ny - 1:0:-1,
156+
:nz + 1]
157+
gint_rep[nx + 1:, ny + 1:, :nz + 1] = gint_rep[nx - 1:0:-1,
158+
ny - 1:0:-1,
159+
:nz + 1]
160+
gint_rep[:nx + 1, :ny + 1, nz + 1:] = gint_rep[:nx + 1,
161+
:ny + 1,
162+
nz - 1:0:-1]
163+
gint_rep[nx + 1:, :ny + 1, nz + 1:] = gint_rep[nx - 1:0:-1,
164+
:ny + 1,
165+
nz - 1:0:-1]
166+
gint_rep[:nx + 1, ny + 1:, nz + 1:] = gint_rep[:nx + 1,
167+
ny - 1:0:-1,
168+
nz - 1:0:-1]
169+
gint_rep[nx + 1:, ny + 1:, nz + 1:] = gint_rep[nx - 1:0:-1,
170+
ny - 1:0:-1,
171+
nz - 1:0:-1]
172+
173+
gint_rep_dev = self.context.nparray_to_context_array(gint_rep)
174+
self.fftplan.transform(gint_rep_dev)
175+
176+
self.dx = dx
177+
self.dy = dy
178+
self.dz = dz
179+
self._gint_rep_transf_dev = gint_rep_dev
180+
104181
#@profile
105182
def solve(self, rho):
106183

@@ -213,6 +290,57 @@ def __init__(self, dx, dy, dz, nx, ny, nz, context=None, fftplan=None):
213290
self._gint_rep_transf_dev = gint_rep_transf_dev
214291
self.fftplan = fftplan
215292

293+
def refresh_geometry(self, x_grid, y_grid, z_grid):
294+
"""
295+
Rebuild the 2D Green's function used by the solver.
296+
297+
The number of cells in each direction must match the original
298+
solver.
299+
300+
Args:
301+
x_grid (array-like): Monotonic array of x cell centers in
302+
meters.
303+
y_grid (array-like): Monotonic array of y cell centers in
304+
meters.
305+
z_grid (array-like): Monotonic array of z cell centers in
306+
meters.
307+
"""
308+
nx, ny, nz = len(x_grid), len(y_grid), len(z_grid)
309+
assert (nx, ny, nz) == (self.nx, self.ny, self.nz), (
310+
"refresh_geometry requires unchanged (nx, ny, nz)"
311+
)
312+
313+
dx = float(x_grid[1] - x_grid[0])
314+
dy = float(y_grid[1] - y_grid[0])
315+
dz = float(z_grid[1] - z_grid[0])
316+
317+
xg_F = np.arange(0, nx + 2) * dx - dx / 2.0
318+
yg_F = np.arange(0, ny + 2) * dy - dy / 2.0
319+
XX_F, YY_F = np.meshgrid(xg_F, yg_F, indexing='ij')
320+
321+
F_temp = primitive_func_2p5d(XX_F, YY_F)
322+
323+
gint_rep = np.zeros((2 * nx, 2 * ny), dtype=np.complex128, order='F')
324+
gint_rep[:nx + 1, :ny + 1] = (
325+
F_temp[1:, 1:]
326+
- F_temp[:-1, 1:]
327+
- F_temp[1:, :-1]
328+
+ F_temp[:-1, :-1]
329+
)
330+
gint_rep[nx + 1:, :ny + 1] = gint_rep[nx - 1:0:-1, :ny + 1]
331+
gint_rep[:nx + 1, ny + 1:] = gint_rep[:nx + 1, ny - 1:0:-1]
332+
gint_rep[nx + 1:, ny + 1:] = gint_rep[nx - 1:0:-1, ny - 1:0:-1]
333+
334+
gint_rep_transf = np.fft.fftn(gint_rep, axes=(0, 1))
335+
gint_rep_transf_dev = self.context.nparray_to_context_array(
336+
np.atleast_3d(gint_rep_transf)
337+
)
338+
339+
self.dx = dx
340+
self.dy = dy
341+
self.dz = dz
342+
self._gint_rep_transf_dev = gint_rep_transf_dev
343+
216344
class FFTSolver2p5DAveraged(Solver):
217345

218346
def __init__(self, dx, dy, dz, nx, ny, nz, context=None, fftplan=None):

0 commit comments

Comments
 (0)