Skip to content

Commit 9042bc9

Browse files
Don't keep rebuilding cross sections, default num points set for model
1 parent 1d7de1c commit 9042bc9

File tree

5 files changed

+167
-14
lines changed

5 files changed

+167
-14
lines changed
Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@
1111
b1 = lhc_with_metadata['b1']
1212
lhc_length = b1.get_length()
1313

14-
aperture_model = Aperture.from_line_with_madx_metadata(b1, line_name='b1', context=context)
14+
aperture_model = Aperture.from_line_with_madx_metadata(b1, context=context)
1515

1616
mqxfa_name = 'mqy.4r1.b1'
1717

1818
# Calculate n1 with the ``rays`` method
1919
sig_rays, tw_rays, aper_rays, _ = aperture_model.get_aperture_sigmas_at_element(
20-
line_name="b1",
2120
element_name=mqxfa_name,
2221
resolution=0.1,
2322
cross_sections_num_points=100,
@@ -26,7 +25,6 @@
2625

2726
# Calculate n1's with the ``bisection`` method
2827
sig_bisect, tw_bisect, aper_bisect, max_envelope = aperture_model.get_aperture_sigmas_at_element(
29-
line_name="b1",
3028
element_name=mqxfa_name,
3129
resolution=0.1,
3230
cross_sections_num_points=100,
@@ -35,7 +33,6 @@
3533

3634
# Get envelope at arbitrary sigma
3735
envelopes, aper_envel, tw_envel = aperture_model.get_apertures_and_envelope_at_element(
38-
line_name='b1',
3936
element_name=mqxfa_name,
4037
resolution=0.1,
4138
sigmas=1,
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import xtrack as xt
2+
import xobjects as xo
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
from cgeom.aperture import Aperture, transform_matrix
6+
from cgeom.structures import ApertureModel, ApertureType, Circle, Profile, ProfilePosition, Rectangle, TypePosition
7+
8+
9+
env = xt.Environment()
10+
11+
l = 1
12+
dx = 1
13+
angle = np.deg2rad(30)
14+
l_straight = dx / np.sin(angle / 2)
15+
rho = 0.5 * l_straight / np.sin(angle / 2)
16+
l_curv = rho * angle
17+
18+
drift = env.new('drift', xt.Drift, length=l)
19+
rot_plus = env.new('rot_plus', xt.Bend, length=l_curv, angle=angle, k0=0)
20+
rot_minus = env.new('rot_minus', xt.Bend, length=l_curv, angle=-angle, k0=0)
21+
22+
line = env.new_line(
23+
name='line',
24+
components=[drift, rot_plus, drift, drift, rot_minus, drift],
25+
)
26+
27+
sv = line.survey()
28+
29+
circle = Circle(radius=2)
30+
rectangle = Rectangle(half_width=2, half_height=0.5)
31+
32+
profiles = [
33+
Profile(shape=circle, tol_r=0, tol_x=0, tol_y=0),
34+
Profile(shape=rectangle, tol_r=0, tol_x=0, tol_y=0),
35+
]
36+
37+
profile_positions = [
38+
ProfilePosition(profile_index=0, s_position=s)
39+
for s in [0, 11]
40+
]
41+
42+
types = [
43+
ApertureType(curvature=0., positions=profile_positions),
44+
]
45+
46+
type_positions = [
47+
TypePosition(
48+
type_index=0,
49+
survey_reference_name='drift::0',
50+
survey_index=sv.name.tolist().index('drift::0'),
51+
transformation=transform_matrix(dx=-1.5),
52+
),
53+
]
54+
55+
model = ApertureModel(
56+
line_name='line',
57+
type_positions=type_positions,
58+
types=types,
59+
profiles=profiles,
60+
type_names=['type0'],
61+
profile_names=['circle', 'rectangle'],
62+
)
63+
64+
ax = plt.figure().add_subplot(projection='3d')
65+
ax.plot(sv.Z, sv.X, sv.Y, c='b')
66+
ax.set_xlabel('Z [m]')
67+
ax.set_ylabel('X [m]')
68+
ax.set_zlabel('Y [m]')
69+
70+
ax.auto_scale_xyz([0, 12], [-6, 6], [-6, 6])
71+
72+
aper = Aperture(env, model, cross_sections=None)
73+
74+
75+
def matrix_from_survey_point(sv_row):
76+
matrix = np.identity(4)
77+
matrix[:3, 0] = sv_row.ex
78+
matrix[:3, 1] = sv_row.ey
79+
matrix[:3, 2] = sv_row.ez
80+
matrix[:3, 3] = np.hstack([sv_row.X, sv_row.Y, sv_row.Z])
81+
return matrix
82+
83+
84+
def poly2d_to_hom(poly2d):
85+
num_points = poly2d.shape[0]
86+
poly_hom = np.column_stack((poly2d, np.zeros(num_points), np.ones(num_points))).T
87+
return poly_hom
88+
89+
90+
for type_pos in aper.model.type_positions:
91+
aper_type = aper.model.type_for_position(type_pos)
92+
sv_ref = sv.rows[type_pos.survey_index]
93+
94+
sv_ref_matrix = matrix_from_survey_point(sv_ref)
95+
type_matrix = type_pos.transformation.to_nparray()
96+
97+
for profile_pos in aper_type.positions:
98+
profile = aper.model.profile_for_position(profile_pos)
99+
100+
num_points = 100
101+
poly = aper.polygon_for_profile(profile, num_points)
102+
poly_hom = poly2d_to_hom(poly)
103+
104+
profile_position_matrix = transform_matrix(
105+
dx=profile_pos.shift_x,
106+
dy=profile_pos.shift_y,
107+
ds=profile_pos.s_position,
108+
theta=profile_pos.rot_y,
109+
phi=profile_pos.rot_x,
110+
psi=profile_pos.rot_z,
111+
)
112+
113+
poly_in_sv_frame = sv_ref_matrix @ type_matrix @ profile_position_matrix @ poly_hom
114+
115+
xs, ys, zs = poly_in_sv_frame[:3]
116+
ax.plot(zs, xs, ys, c='r')
117+
118+
119+
def tangents_at_s(line, s_positions):
120+
"""Return a local coordinate system (each represented by a homogeneous matrix) at all ``s_positions``."""
121+
tangents = np.zeros(shape=(len(s_positions), 4, 4), dtype=np.float32)
122+
line_sliced = line.copy()
123+
line_sliced.cut_at_s(s_positions)
124+
survey_sliced = line_sliced.survey()
125+
sv_indices = np.searchsorted(survey_sliced.s, s_positions)
126+
127+
for idx, sv_idx in enumerate(sv_indices):
128+
row = survey_sliced.rows[sv_idx]
129+
tangents[idx, :3, 0] = row.ex
130+
tangents[idx, :3, 1] = row.ey
131+
tangents[idx, :3, 2] = row.ez
132+
tangents[idx, :, 3] = np.hstack([row.X, row.Y, row.Z, 1])
133+
134+
return tangents
135+
136+
137+
s_for_cuts = np.linspace(1, 11, 20)
138+
profiles, tangents = aper.profiles_at_s('line', s_for_cuts)
139+
tangents2 = tangents_at_s(line, s_for_cuts)
140+
141+
xo.assert_allclose(tangents, tangents2, atol=1e-6, rtol=1e-6)
142+
143+
for idx, s in enumerate(s_for_cuts):
144+
profile = profiles[idx]
145+
profile_hom = poly2d_to_hom(profile)
146+
profile_in_sv_frame = tangents[idx] @ profile_hom
147+
profile_in_sv_frame2 = tangents2[idx] @ profile_hom
148+
149+
xs, ys, zs = profile_in_sv_frame[:3]
150+
ax.plot(zs, xs, ys, c='g')
151+
152+
plt.show()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ include-package-data = true
6161
[project.entry-points.xobjects]
6262
include = "xtrack"
6363

64-
[pytest]
64+
[tool.pytest]
6565
markers = [
6666
"context_dependent: marks test as one that depends on the execution context",
6767
]

xtrack/aperture/aperture.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,20 @@ def __init__(
8383
self,
8484
line: Line,
8585
model: ApertureModel,
86-
cross_sections,
87-
halo_params=None,
86+
num_profile_points: int = 128,
87+
halo_params: Optional[dict] = None,
8888
context: Optional[XContext] = None,
8989
s_tol=1e-3,
9090
):
9191
self.line = line
9292
self.model = model # positioning of types in line frame
93-
self.cross_sections = cross_sections
9493
self.halo_params = self.halo_params.copy()
9594
self.context = context or xo.ContextCpu()
9695
self.s_tol = s_tol
9796

97+
self.num_profile_points = num_profile_points
98+
self._cross_sections: Optional[CrossSections] = None
99+
98100
if halo_params is not None:
99101
self.halo_params.update(halo_params)
100102

@@ -407,7 +409,6 @@ def _build_aperture_model(
407409
aperture = cls(
408410
line=line,
409411
model=model,
410-
cross_sections=None,
411412
context=context,
412413
)
413414

@@ -481,8 +482,6 @@ def get_aperture_sigmas_at_s(
481482
line_sliced.cut_at_s(s_positions)
482483
s_start, s_end = s_positions[0], s_positions[-1]
483484

484-
self.cross_sections = self._build_cross_sections(cross_sections_num_points)
485-
486485
sliced_twiss = line_sliced.twiss(init=twiss_init).rows[s_start:s_end:'s']
487486

488487
num_slices = len(sliced_twiss.s)
@@ -551,8 +550,6 @@ def get_apertures_and_envelope_at_s(
551550
line_sliced.cut_at_s(s_positions)
552551
s_start, s_end = s_positions[0], s_positions[-1]
553552

554-
self.cross_sections = self._build_cross_sections(cross_sections_num_points)
555-
556553
sliced_twiss = line_sliced.twiss(init=twiss_init).rows[s_start:s_end:'s']
557554
num_slices = len(sliced_twiss.s)
558555
twiss_data = TwissData.from_twiss_table(self.line.particle_ref, sliced_twiss)
@@ -607,6 +604,13 @@ def profiles_at_s(self, s_positions: Collection[float]) -> Tuple[PolygonPoints32
607604
)
608605
return placeholders, sv_sliced.tangent.to_nparray()
609606

607+
@property
608+
def cross_sections(self) -> CrossSections:
609+
if not self._cross_sections:
610+
self._cross_sections = self._build_cross_sections(self.num_profile_points)
611+
612+
return self._cross_sections
613+
610614
def _get_cuts_at_element(self, element_name: str, resolution: Optional[float]) -> List[float]:
611615
"""Get list of s positions so that the element ``element_name`` is cut with a ``resolution``."""
612616
element = self.line[element_name]

xtrack/aperture/structures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def __init__(
145145
self,
146146
type_positions: List[TypePosition],
147147
types: List[ApertureType],
148-
profiles: List[Profile],
148+
profiles: List[ShapeTypes],
149149
type_names: List[str],
150150
profile_names: List[str],
151151
**kwargs,

0 commit comments

Comments
 (0)