diff --git a/tests/test_madnginterface.py b/tests/test_madnginterface.py index 8c21c6006..0643b5cde 100644 --- a/tests/test_madnginterface.py +++ b/tests/test_madnginterface.py @@ -3,6 +3,7 @@ import xobjects as xo import pathlib import numpy as np +from xtrack._temp import lhc_match as lm test_data_folder = pathlib.Path( __file__).parent.joinpath('../test_data').absolute() @@ -68,7 +69,7 @@ def test_madng_interface_with_multipole_errors_and_misalignments(): line[nn_quad].shift_y = sy * line.ref['on_error'] line[nn_quad].rot_s_rad = rr * line.ref['on_error'] line[nn_quad].knl[2] = kkk * line.ref['on_error'] - tw = line.madng_twiss() + tw = line.madng_twiss(coupling_edw_teng=True, compute_chromatic_properties=True) xo.assert_allclose(tw.x, tw.x_ng, atol=5e-4*tw.x.std(), rtol=0) xo.assert_allclose(tw.y, tw.y_ng, atol=5e-4*tw.y.std(), rtol=0) @@ -82,7 +83,7 @@ def test_madng_interface_with_multipole_errors_and_misalignments(): xo.assert_allclose(tw.by_chrom, tw.by_ng, atol=5e-3*tw.wy_chrom.max(), rtol=0) line['on_error'] = 0 - tw = line.madng_twiss() + tw = line.madng_twiss(coupling_edw_teng=True, compute_chromatic_properties=True) xo.assert_allclose(tw.x, 0, atol=1e-10, rtol=0) xo.assert_allclose(tw.y, 0, atol=1e-10, rtol=0) xo.assert_allclose(tw.betx2, 0, atol=1e-10, rtol=0) @@ -173,7 +174,7 @@ def test_madng_interface_with_slicing(): line.cut_at_s(np.arange(1000)) tw_xs = line.twiss4d() - tw = line.madng_twiss() + tw = line.madng_twiss(coupling_edw_teng=True, compute_chromatic_properties=True) assert len(tw) == len(tw_xs) @@ -196,16 +197,16 @@ def test_madng_twiss_with_initial_conditions(): line = xt.load(test_data_folder / 'hllhc15_thick/lhc_thick_with_knobs.json') #pytest.set_trace() - tw_xs = line.twiss(betx=120, bety=150) - tw = line.madng_twiss(beta11=120, beta22=150) + tw_xs = line.twiss(betx=120, bety=150, alfx=5, alfy=5, dx=1e-4) + tw = line.madng_twiss(beta11=120, beta22=150, alfa11=5, alfa22=5, dx=1e-4) assert len(tw) == len(tw_xs) assert len(tw.betx) == len(tw.beta11_ng) - xo.assert_allclose(tw.betx, tw.beta11_ng, rtol=1e-7, atol=1e-6) - xo.assert_allclose(tw.bety, tw.beta22_ng, rtol=1e-7, atol=1e-6) - xo.assert_allclose(tw.alfx, tw.alfa11_ng, rtol=1e-7, atol=1e-6) - xo.assert_allclose(tw.alfy, tw.alfa22_ng, rtol=1e-7, atol=1e-6) + xo.assert_allclose(tw.betx, tw.beta11_ng, rtol=1e-6, atol=1e-6) + xo.assert_allclose(tw.bety, tw.beta22_ng, rtol=1e-6, atol=1e-6) + xo.assert_allclose(tw.alfx, tw.alfa11_ng, rtol=1e-6, atol=1e-6) + xo.assert_allclose(tw.alfy, tw.alfa22_ng, rtol=1e-6, atol=1e-6) xo.assert_allclose(tw.dx, tw.dx_ng, rtol=1e-8, atol=1e-6) xo.assert_allclose(tw.dy, tw.dy_ng, rtol=1e-8, atol=1e-6) xo.assert_allclose(tw.dpx, tw.dpx_ng, rtol=1e-8, atol=1e-6) @@ -213,8 +214,8 @@ def test_madng_twiss_with_initial_conditions(): xo.assert_allclose(tw.x, tw.x_ng, rtol=1e-8, atol=1e-6) xo.assert_allclose(tw.y, tw.y_ng, rtol=1e-8, atol=1e-6) - tw2_xs = line.twiss(start='s.ds.l8.b1', end='ip1', betx=100, bety=34) - tw2_xsng = line.madng_twiss(start='s.ds.l8.b1', end='ip1', beta11=100, beta22=34, xsuite_tw=False) + tw2_xs = line.twiss(start='s.ds.l8.b1', end='ip1', betx=100, bety=34, dx=1e-5) + tw2_xsng = line.madng_twiss(start='s.ds.l8.b1', end='ip1', beta11=100, beta22=34, dx=1e-5, xsuite_tw=False) assert len(tw2_xs.betx) == len(tw2_xsng.beta11_ng) xo.assert_allclose(tw2_xs.betx, tw2_xsng.beta11_ng, rtol=1e-8, atol=1e-6) @@ -248,16 +249,34 @@ def test_madng_twiss_with_initial_conditions(): xo.assert_allclose(tw3_xsng.alfx, tw3_xsng.alfa11_ng, rtol=1e-8, atol=1e-6) xo.assert_allclose(tw3_xsng.alfy, tw3_xsng.alfa22_ng, rtol=1e-8, atol=1e-6) + tw4_xs = line.twiss(start='ip3', end='ip4', betx=121.5668, bety=218.58374, alfx=2.295, alfy=-2.6429, dx=-0.51) + tw4_xsng = line.madng_twiss(start='ip3', end='ip4', beta11=121.5668, beta22=218.58374, alfa11=2.295, + alfa22=-2.6429, dx=-0.51, xsuite_tw=False) + + assert len(tw4_xs.betx) == len(tw4_xsng.beta11_ng) + xo.assert_allclose(tw4_xs.betx, tw4_xsng.beta11_ng, rtol=1e-6, atol=1e-5) + xo.assert_allclose(tw4_xs.bety, tw4_xsng.beta22_ng, rtol=1e-6, atol=1e-5) + xo.assert_allclose(tw4_xs.alfx, tw4_xsng.alfa11_ng, rtol=1e-6, atol=1e-5) + xo.assert_allclose(tw4_xs.alfy, tw4_xsng.alfa22_ng, rtol=1e-6, atol=1e-5) + xo.assert_allclose(tw4_xs.dx, tw4_xsng.dx_ng, rtol=1e-7, atol=1e-8) + xo.assert_allclose(tw4_xs.dy, tw4_xsng.dy_ng, rtol=1e-7, atol=1e-8) + xo.assert_allclose(tw4_xs.x, tw4_xsng.x_ng, rtol=1e-8, atol=1e-10) + xo.assert_allclose(tw4_xs.y, tw4_xsng.y_ng, rtol=1e-8, atol=1e-10) + xo.assert_allclose(tw4_xs.px, tw4_xsng.px_ng, rtol=1e-8, atol=1e-10) + xo.assert_allclose(tw4_xs.py, tw4_xsng.py_ng, rtol=1e-8, atol=1e-10) + xo.assert_allclose(tw4_xs.mux, tw4_xsng.mu1_ng, rtol=1e-8, atol=1e-5) + xo.assert_allclose(tw4_xs.muy, tw4_xsng.mu2_ng, rtol=1e-8, atol=1e-5) + def test_madng_slices(): line = xt.load(test_data_folder / 'hllhc15_thick/lhc_thick_with_knobs.json') tw = line.twiss4d() - twng = line.madng_twiss() + twng = line.madng_twiss(compute_chromatic_properties=True) line.cut_at_s(np.linspace(0, line.get_length(), 5000)) tw_sliced = line.twiss4d() - twng_sliced = line.madng_twiss() + twng_sliced = line.madng_twiss(compute_chromatic_properties=True) tt_sliced = line.get_table() assert np.all(np.array(sorted(list(set(tt_sliced.element_type)))) == @@ -302,3 +321,201 @@ def test_madng_slices(): xo.assert_allclose(twng_ip.wy_ng, twng_ip_sliced.wy_ng, rtol=1e-3) xo.assert_allclose(twng_ip.dx_ng, twng_ip_sliced.dx_ng, atol=1e-6) xo.assert_allclose(twng_ip.dy_ng, twng_ip_sliced.dy_ng, atol=1e-6) + +def test_madng_match_optics(): + collider = xt.Environment.from_json(test_data_folder / + 'hllhc15_thick/hllhc15_collider_thick.json') + collider.vars.load_madx(test_data_folder / + 'hllhc15_thick/opt_round_150_1500.madx') + + line = collider.lhcb1 + tw0 = line.madng_twiss() + + lm.set_var_limits_and_steps(collider) + + # Match with Xsuite Targets + opt = line.match( + solve=False, + default_tol={None: 1e-8, 'betx': 1e-6, 'bety': 1e-6, 'alfx': 1e-6, 'alfy': 1e-6}, + start='s.ds.l8.b1', end='ip1', + init=tw0, init_at=xt.START, + vary=[ + # Only IR8 quadrupoles including DS + xt.VaryList(['kq6.l8b1', 'kq7.l8b1', 'kq8.l8b1', 'kq9.l8b1', 'kq10.l8b1', + 'kqtl11.l8b1', 'kqt12.l8b1', 'kqt13.l8b1', + 'kq4.l8b1', 'kq5.l8b1', 'kq4.r8b1', 'kq5.r8b1', + 'kq6.r8b1', 'kq7.r8b1', 'kq8.r8b1', 'kq9.r8b1', + 'kq10.r8b1', 'kqtl11.r8b1', 'kqt12.r8b1', 'kqt13.r8b1'])], + targets=[ + xt.TargetSet(at='ip8', tars=('betx', 'bety', 'alfx', 'alfy', 'dx', 'dpx'), value=tw0, weight=1), + xt.TargetSet(at='ip1', betx=0.15, bety=0.1, alfx=0, alfy=0, dx=0, dpx=0, weight=1), + xt.TargetRelPhaseAdvance('mux', value = tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + xt.TargetRelPhaseAdvance('muy', value = tw0['muy', 'ip1.l1'] - tw0['muy', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + ], + use_tpsa=True) + + opt.step(30) + + assert opt._err.call_counter < 20 + assert len(opt.log()) < 10 + + tw = line.twiss(init=tw0, start='s.ds.l8.b1', end='ip1') + + xo.assert_allclose(tw['betx', 'ip1'], 0.15, atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip1'], 0.1, atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip1'], 0., atol=1e-6, rtol=0) + + xo.assert_allclose(tw['betx', 'ip8'], tw0['betx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip8'], tw0['bety', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip8'], tw0['alfx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip8'], tw0['alfy', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip8'], tw0['dx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip8'], tw0['dy', 'ip8'], atol=1e-6, rtol=0) + + xo.assert_allclose(tw['mux', 'ip1.l1'] - tw['mux', 's.ds.l8.b1'], tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['muy', 'ip1.l1'] - tw['muy', 's.ds.l8.b1'], tw0['muy', 'ip1.l1'] - tw0['muy', 's.ds.l8.b1'], atol=1e-6, rtol=0) + + opt.reload(0) + opt.actions[0].cleanup() + + # Match with MAD-NG and Xsuite Targets mixed + opt = line.match( + solve=False, + default_tol={None: 1e-8, 'betx': 1e-6, 'bety': 1e-6, 'alfx': 1e-6, 'alfy': 1e-6}, + start='s.ds.l8.b1', end='ip1', + init=tw0, init_at=xt.START, + vary=[ + # Only IR8 quadrupoles including DS + xt.VaryList(['kq6.l8b1', 'kq7.l8b1', 'kq8.l8b1', 'kq9.l8b1', 'kq10.l8b1', + 'kqtl11.l8b1', 'kqt12.l8b1', 'kqt13.l8b1', + 'kq4.l8b1', 'kq5.l8b1', 'kq4.r8b1', 'kq5.r8b1', + 'kq6.r8b1', 'kq7.r8b1', 'kq8.r8b1', 'kq9.r8b1', + 'kq10.r8b1', 'kqtl11.r8b1', 'kqt12.r8b1', 'kqt13.r8b1'])], + targets=[ + xt.TargetSet(at='ip8', tars=('beta11_ng', 'bety', 'alfa11_ng', 'alfy', 'dx_ng', 'dpx'), value=tw0, weight=1), + xt.TargetSet(at='ip1', betx=0.15, beta22_ng=0.1, alfx=0, alfa22_ng=0, dx=0, dpx_ng=0, weight=1), + xt.TargetRelPhaseAdvance('mux', value = tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + xt.TargetRelPhaseAdvance('mu2_ng', value = tw0['mu2_ng', 'ip1.l1'] - tw0['mu2_ng', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + ], + use_tpsa=True) + + opt.step(30) + + assert opt._err.call_counter < 20 + assert len(opt.log()) < 10 + + tw = line.twiss(init=tw0, start='s.ds.l8.b1', end='ip1') + + xo.assert_allclose(tw['betx', 'ip1'], 0.15, atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip1'], 0.1, atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip1'], 0., atol=1e-6, rtol=0) + + xo.assert_allclose(tw['betx', 'ip8'], tw0['betx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip8'], tw0['bety', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip8'], tw0['alfx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip8'], tw0['alfy', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip8'], tw0['dx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip8'], tw0['dy', 'ip8'], atol=1e-6, rtol=0) + + xo.assert_allclose(tw['mux', 'ip1.l1'] - tw['mux', 's.ds.l8.b1'], tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['muy', 'ip1.l1'] - tw['muy', 's.ds.l8.b1'], tw0['muy', 'ip1.l1'] - tw0['muy', 's.ds.l8.b1'], atol=1e-6, rtol=0) + + opt.reload(0) + opt.actions[0].cleanup() + + # Match on full line without initial conditions + opt = line.match( + solve=False, + default_tol={None: 1e-8, 'betx': 1e-6, 'bety': 1e-6, 'alfx': 1e-6, 'alfy': 1e-6}, + vary=[ + # Only IR8 quadrupoles including DS + xt.VaryList(['kq6.l8b1', 'kq7.l8b1', 'kq8.l8b1', 'kq9.l8b1', 'kq10.l8b1', + 'kqtl11.l8b1', 'kqt12.l8b1', 'kqt13.l8b1', + 'kq4.l8b1', 'kq5.l8b1', 'kq4.r8b1', 'kq5.r8b1', + 'kq6.r8b1', 'kq7.r8b1', 'kq8.r8b1', 'kq9.r8b1', + 'kq10.r8b1', 'kqtl11.r8b1', 'kqt12.r8b1', 'kqt13.r8b1'])], + targets=[ + xt.TargetSet(at='ip8', tars=('beta11_ng', 'beta22_ng', 'alfa11_ng', 'alfa22_ng', 'dx_ng', 'dpx_ng'), value=tw0, weight=1), + xt.TargetSet(at='ip1.l1', beta11_ng=0.15, beta22_ng=0.1, alfa11_ng=0, alfa22_ng=0, dx_ng=0, dpx_ng=0, weight=1), + xt.TargetRelPhaseAdvance('mu1_ng', value = tw0['mu1_ng', 'ip1.l1'] - tw0['mu1_ng', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + xt.TargetRelPhaseAdvance('mu2_ng', value = tw0['mu2_ng', 'ip1.l1'] - tw0['mu2_ng', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + ], + use_tpsa=True) + + opt.step(30) + + assert opt._err.call_counter < 20 + assert len(opt.log()) < 10 + + tw = line.twiss(init=tw0) + + xo.assert_allclose(tw['betx', 'ip1.l1'], 0.15, atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip1.l1'], 0.1, atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip1.l1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip1.l1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip1.l1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip1.l1'], 0., atol=1e-6, rtol=0) + + xo.assert_allclose(tw['betx', 'ip8'], tw0['betx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip8'], tw0['bety', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip8'], tw0['alfx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip8'], tw0['alfy', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip8'], tw0['dx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip8'], tw0['dy', 'ip8'], atol=1e-6, rtol=0) + + xo.assert_allclose(tw['mux', 'ip1.l1'] - tw['mux', 's.ds.l8.b1'], tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['muy', 'ip1.l1'] - tw['muy', 's.ds.l8.b1'], tw0['muy', 'ip1.l1'] - tw0['muy', 's.ds.l8.b1'], atol=1e-6, rtol=0) + +def test_madng_orbit_bump(): + env = xt.Environment() + env.vars.default_to_zero = True + line = env.new_line(length=10, components=[ + env.new('corr1', xt.Multipole, isthick=True, + knl=['kick_h_1'], ksl=['kick_v_1'], length=0.1, at=1), + env.new('corr2', xt.Multipole, isthick=True, + knl=['kick_h_2'], ksl=['kick_v_2'], length=0.1, at=2), + env.new('corr3', xt.Multipole, isthick=True, + knl=['kick_h_3'], ksl=['kick_v_3'], length=0.1, at=8), + env.new('corr4', xt.Multipole, isthick=True, + knl=['kick_h_4'], ksl=['kick_v_4'], length=0.1, at=9), + env.new('mid', xt.Marker, at=5), + env.new('end', xt.Marker, at=10) + ]) + line.set_particle_ref('proton', p0c=26e9) + + opt = line.match( + solve=False, + betx=1, bety=1, + vary=xt.VaryList(['kick_h_1', 'kick_v_1', + 'kick_h_2', 'kick_v_2', + 'kick_h_3', 'kick_v_3', + 'kick_h_4', 'kick_v_4']), + targets=[ + xt.TargetSet(x=1e-3, y=-2e-3, px=0, py=0, at='mid'), + xt.TargetSet(x=0, y=0, px=0, py=0, at='end'), + ], + use_tpsa=True + ) + + jac_ng = opt._err.get_jacobian(opt._err._get_x()) + + jac_opt = np.array([[-40, 0, -30, 0, 0, 0, 0, 0], + [-100, 0, -100, 0, 0, 0, 0, 0], + [0, 40, 0, 30, 0, 0, 0, 0], + [0, 100, 0, 100, 0, 0, 0, 0], + [-90, 0, -80, 0, -20, 0, -10, 0], + [-100, 0, -100, 0, -100, 0, -100, 0], + [0, 90, 0, 80, 0, 20, 0, 10], + [0, 100, 0, 100, 0, 100, 0, 100]]) + + xo.assert_allclose(jac_ng, jac_opt, rtol=1e-12, atol=1e-12) + + opt.solve() + + assert opt._err.call_counter < 7 \ No newline at end of file diff --git a/xtrack/mad_writer.py b/xtrack/mad_writer.py index a9fe4a231..6c7cf59a7 100644 --- a/xtrack/mad_writer.py +++ b/xtrack/mad_writer.py @@ -543,6 +543,7 @@ def srotation_to_mad_str(eref, mad_type=MadType.MADX, substituted_vars=None): def element_to_mad_str( name, + env_name, line, mad_type=MadType.MADX, substituted_vars=None, @@ -551,8 +552,9 @@ def element_to_mad_str( Generic converter for elements to MADX/MAD-NG. """ - el = line.element_dict[name] - eref = _get_eref(line, name) + el = line.element_dict[env_name] + eref = _get_eref(line, env_name) + parent_flag = hasattr(el, '_parent') if el.__class__ == xt.Marker or parent_flag and el._parent.__class__ == xt.Marker: @@ -572,7 +574,7 @@ def element_to_mad_str( _handle_transforms(tokens, eref, mad_type=mad_type, substituted_vars=substituted_vars) if mad_type == MadType.MADNG: - tokens = [tokens[0]] + [f"'{name.replace(':', '__')}'"] + tokens[1:] + tokens = [tokens[0]] + [f"'{name}'"] + tokens[1:] tokens = _handle_tokens_madng(tokens, substituted_vars) return ', '.join(tokens) @@ -603,18 +605,15 @@ def to_madx_sequence(line, name='seq', mode='sequence'): tt_name = tt.name tt_s = tt.s tt_isthick = tt.isthick - for ii in range(len(tt.name)): + for ii in range(len(tt.name[:-1])): nn = tt_name[ii] if not(tt_isthick[ii]): s_dict[nn] = tt_s[ii] else: s_dict[nn] = 0.5 * (tt_s[ii] + tt_s[ii+1]) - for nn in line.element_names: - - - el = line.element_dict[nn] - el_str = element_to_mad_str(nn, line, mad_type=MadType.MADX) + el = line.element_dict[tt.env_name[ii]] + el_str = element_to_mad_str(nn, tt.env_name[ii], line, mad_type=MadType.MADX) if nn + '_tilt_entry' in line.element_dict: el_str += ", " + mad_assignment('tilt', _ge(line.element_refs[nn + '_tilt_entry'].angle) / 180. * np.pi, @@ -672,16 +671,16 @@ def to_madng_sequence(line, name='seq', mode='sequence'): else: s_dict[nn] = 0.5 * (tt.s[ii] + tt.s[ii+1]) - el = line.element_dict[nn] + el = line.element_dict[tt.env_name[ii]] - el_str = element_to_mad_str(nn, line, mad_type=MadType.MADNG, substituted_vars=substituted_vars) + el_str = element_to_mad_str(nn, tt.env_name[ii], line, mad_type=MadType.MADNG, substituted_vars=substituted_vars) if el_str is None: continue # Misalignments if hasattr(el, 'shift_x') and hasattr(el, 'shift_y'): - el_str += f", misalign =\\ {{dx={mad_str_or_value(_ge(line.ref[nn].shift_x))}, dy={mad_str_or_value(_ge(line.ref[nn].shift_y))}}}" + el_str += f", misalign =\\ {{dx={mad_str_or_value(_ge(line.ref[tt.env_name[ii]].shift_x))}, dy={mad_str_or_value(_ge(line.ref[tt.env_name[ii]].shift_y))}}}" el_strs.append(el_str) # Chunking sequence diff --git a/xtrack/madng_interface.py b/xtrack/madng_interface.py index dce0b70d1..9e4b1f68d 100644 --- a/xtrack/madng_interface.py +++ b/xtrack/madng_interface.py @@ -1,11 +1,13 @@ import numpy as np + from .match import Action import os import uuid -from .mad_writer import mad_str_or_value import xtrack as xt +from xtrack.particles.particles import ptau2delta, dptau2ddelta + NG_XS_MAP = { 'beta11': 'betx', 'beta22': 'bety', @@ -15,14 +17,56 @@ 'mu2': 'muy', } +XS_NG_MAP = { + 'betx': 'beta11', + 'bety': 'beta22', + 'alfx': 'alfa11', + 'alfy': 'alfa22', + 'mux': 'mu1', + 'muy': 'mu2', + 'dx': 'dx', + 'dpx': 'dpx', + 'x': 'x', + 'px': 'px', + 'y': 'y', + 'py': 'py', + 'zeta': 't', + 'delta': 'pt', + 'ptau': 'pt', +} + BETA0_COLUMNS = ['x', 'px', 'y', 'py', 't', 'pt', - 'dx', 'dy', 'dpx', 'dpy', 'ddx', 'ddpx', 'ddy', 'ddpy', 'wx', 'phix', - 'wy', 'phiy', 'mu1', 'mu2', 'mu3', 'dmu1', 'dmu2', 'dmu3', 'r11', - 'r12', 'r21', 'r22', 'alfa11', 'alfa12', 'alfa13', 'alfa21', - 'alfa22', 'alfa23', 'alfa31', 'alfa32', 'alfa33', 'beta11', - 'beta12', 'beta13', 'beta21', 'beta22', 'beta23', 'beta31', - 'beta32', 'beta33', 'gama11', 'gama12', 'gama13', 'gama21', - 'gama22', 'gama23', 'gama31', 'gama32', 'gama33'] + 'dx', 'dy', 'dpx', 'dpy', 'ddx', 'ddpx', 'ddy', 'ddpy', 'wx', 'phix', + 'wy', 'phiy', 'mu1', 'mu2', 'mu3', 'dmu1', 'dmu2', 'dmu3', 'r11', + 'r12', 'r21', 'r22', 'alfa11', 'alfa12', 'alfa13', 'alfa21', + 'alfa22', 'alfa23', 'alfa31', 'alfa32', 'alfa33', 'beta11', + 'beta12', 'beta13', 'beta21', 'beta22', 'beta23', 'beta31', + 'beta32', 'beta33', 'gama11', 'gama12', 'gama13', 'gama21', + 'gama22', 'gama23', 'gama31', 'gama32', 'gama33'] + +TW_BASE_COLUMNS = ['s', 'beta11', 'beta22', 'beta33', 'alfa11', 'alfa22', 'alfa33', + 'gama11', 'gama22', 'gama33', 'x', 'px', 'y', 'py', 't', 'pt', + 'dx', 'dy', 'dpx', 'dpy', 'mu1', 'mu2', 'mu3'] + +OPTFUN_QUANTITIES = ['beta11', 'beta22', 'alfa11', 'alfa22', 'gama11', 'gama22', + 'dx', 'dy', 'dpx', 'dpy', 'mu1', 'mu2'] + +CHROM_COLUMNS = ['dmu1', 'dmu2', 'dmu3', 'Dx', 'Dpx', 'Dy', + 'Dpy', 'ddx', 'ddpx', 'ddy', 'ddpy', 'wx', 'wy', 'phix', 'phiy'] + +COUPLING_COLUMNS = ['alfa12', 'alfa13', 'alfa21', 'alfa23', 'alfa31', 'alfa32', + 'beta12', 'beta13', 'beta21', 'beta23', 'beta31', 'beta32', + 'gama12', 'gama13', 'gama21', 'gama23', 'gama31', 'gama32', + 'f1001', 'f1010', 'r11', 'r12', 'r21', 'r22'] + +TPSA_ALLOWED_TARGETS = { 'beta11', 'beta22', 'alfa11', 'alfa22', 'dx', 'dpx', 'dy', 'dpy', + 'mu1', 'mu2', 'x', 'px', 'y', 'py', 't', 'pt' } + +XSUITE_MADNG_ENV_NAME = "_xsuite_matching_env" + +def _lua_list(strings): + """Convert Python list → Lua { 'a', 'b', 'c' }.""" + return "{ " + ", ".join(f"'{s}'" for s in strings) + " }" class MadngVars: @@ -30,7 +74,15 @@ def __init__(self, mad): self.mad = mad def __setitem__(self, key, value): - setattr(self.mad.MADX, key.replace('.', '_'), value) + # Check for key if it's a ctpsa or tpsa + var = f"MADX['{key.replace('.', '_')}']" + is_tpsa = self.mad.send(f"py:send(MAD.typeid.is_tpsa({var}) or MAD.typeid.is_ctpsa({var}))").recv() + if is_tpsa: + self.mad.send(f"{var}:set0(py:recv())").send(value) + else: + self.mad[var] = value + + #Expressions still to be handled, could use the following: # mng.send( # MADX:open_env() @@ -92,25 +144,29 @@ def _build_rdt_script(mng_sequence_name, rdts, columns): def _build_beta0_block_string(tw_kwargs): flag_init = False - beta0_keys = [] + beta0_dict = {} for k in tw_kwargs.keys(): if k in BETA0_COLUMNS: - beta0_keys.append(k) + beta0_dict[k] = tw_kwargs[k] + flag_init = True + elif k in XS_NG_MAP: + beta0_dict[XS_NG_MAP[k]] = tw_kwargs[k] flag_init = True if flag_init: # Construct beta0 string beta0_str = 'X0 = beta0 {' - for k in beta0_keys: - beta0_str += f'{k} = {tw_kwargs[k]}, ' + for k, v in beta0_dict.items(): + beta0_str += f'{k} = {v}, ' beta0_str = beta0_str[:-2] + '}, ' else: beta0_str = '' return beta0_str -def _tw_ng(line, rdts=(), normal_form=True, +def _tw_ng(line, rdts=(), normal_form=False, mapdef_twiss=2, mapdef_normal_form=4, - nslice=3, xsuite_tw=True, X0=None, **tw_kwargs): + nslice=3, xsuite_tw=True, X0=None, compute_chromatic_properties=False, + coupling_edw_teng=False, **tw_kwargs): _action = ActionTwissMadng(line, { "rdts": rdts, @@ -134,25 +190,24 @@ def _tw_ng(line, rdts=(), normal_form=True, raise NotImplementedError('TwissTable as init not implemented.') X0_str = _build_beta0_block_string(tw_kwargs) else: - X0_str = f'X0={X0}, ' + X0_str = f'X0 = {X0}, ' + + if (start is None) != (end is None): + raise ValueError('Start and end must be specified together.') - if not (start is None and end is None and init is None) \ - and not (start is not None and end is not None and X0_str != ''): - raise ValueError('Start and end must be specified together, as well as initial conditions, if open twiss is used.') + if start is not None and end is not None and not X0_str: + raise ValueError('Initial conditions must be specified when start and end are given.') full_twiss_str = '' - tw_columns = ['s', 'beta11', 'beta22', 'alfa11', 'alfa22', - 'x', 'px', 'y', 'py', 't', 'pt', - 'dx', 'dy', 'dpx', 'dpy', 'mu1', 'mu2'] - if start is None and end is None: - extended_tw_columns = ['beta12', 'beta21', 'alfa12', 'alfa21', - 'wx', 'wy', 'phix', 'phiy', 'dmu1', 'dmu2', - 'f1001', 'f1010', 'r11', 'r12', 'r21', 'r22', - ] - full_twiss_str = f"mapdef={mapdef_twiss}, implicit=true, nslice={nslice}, misalgn=true, coupling=true, chrom=true" - tw_columns += extended_tw_columns + tw_columns = TW_BASE_COLUMNS.copy() + + full_twiss_str = f"implicit=true, nslice={nslice}, misalign=true, coupling={str(coupling_edw_teng).lower()}, chrom={str(compute_chromatic_properties).lower()}" + if coupling_edw_teng: + tw_columns += COUPLING_COLUMNS + if compute_chromatic_properties: + tw_columns += CHROM_COLUMNS columns = tw_columns + list(rdts) send_cmd = _build_column_send_script(columns) @@ -160,7 +215,6 @@ def _tw_ng(line, rdts=(), normal_form=True, if len(rdts) > 0: mng_script = _build_rdt_script(mng._sequence_name, rdts, columns) else: - # If start/end -> range, if only start: cycle - twiss - cycle back range_str = '' if start is not None and end is not None: @@ -175,6 +229,7 @@ def _tw_ng(line, rdts=(), normal_form=True, '''} ''' + send_cmd) + mng.send(mng_script) out = mng.recv('columns') @@ -190,8 +245,10 @@ def _tw_ng(line, rdts=(), normal_form=True, xs_tw_kwargs = { NG_XS_MAP.get(k, k): v for k, v in tw_kwargs.items() } + tw = line.twiss(method='4d', reverse=False, **xs_tw_kwargs) - else: + + if not xsuite_tw: # Handle wrap-around range if i_start > i_end: name_co = np.array(names[i_start:] + names[:i_end + 1] + ('_end_point',)) @@ -237,7 +294,7 @@ def _process_data(data): for nn in rdts: tw[nn] = np.atleast_1d(np.squeeze(out_dct[nn]))[:-1] - if start is None or end is None: + if compute_chromatic_properties: temp_x = tw.wx_ng * np.exp(1j*2*np.pi*tw.phix_ng) tw['ax_ng'] = np.imag(temp_x) tw['bx_ng'] = np.real(temp_x) @@ -310,15 +367,18 @@ def madng_get_init(line, at): if not hasattr(line.tracker, '_madng'): line.build_madng_model() mng = line.tracker._madng + if at == xt.START: + at = "1" + else: + at = f"'{at}'" mng.send(f""" local observed in MAD.element.flags - {mng._sequence_name}:select(observed, {{list = {{'{at}'}}}}) + {mng._sequence_name}:select(observed, {{list = {{{at}}}}}) twpart, mf = twiss {{sequence = {mng._sequence_name}, observe = 1, savemap = true, info = 2}} - {mng._sequence_name}.X0 = twpart['{at}'].__map - {mng._sequence_name}.X0.status = "Aset" ! Bug corrected in next version + {XSUITE_MADNG_ENV_NAME}.X0 = twpart[{at}].__map """) - return f"{mng._sequence_name}.X0" + return f"{XSUITE_MADNG_ENV_NAME}.X0" def _survey_ng(line): if not hasattr(line.tracker, '_madng'): @@ -390,11 +450,507 @@ def prepare(self, force=False): if init is not None and start is not None and end is not None: assert isinstance(init, xt.TwissTable) self.X0 = madng_get_init(self.line, at=start) + elif init is not None: + assert isinstance(init, xt.TwissTable) + self.X0 = madng_get_init(self.line, at=xt.START) + self._alredy_prepared = True def run(self): return self.line.madng_twiss(xsuite_tw = False, X0=self.X0, **self.tw_kwargs) +class ActionTwissMadngTPSA(Action): + def __init__(self, line, vary_names, targets = [], tw_kwargs={}, sum_rmat_tar=0, **kwargs): + self.line = line + self.vary_names = vary_names + self.targets = targets + self.optics_target_locations = None + self.optics_target_quantities = None + self.tw_kwargs = tw_kwargs + self.tw_kwargs.update(kwargs) + self.twiss_flag = None + self._already_prepared = False + self.match_rmat = False + self.match_opt = False + self.sum_rmat_tar = sum_rmat_tar + self.rmat_start_end_list = None + self.rmat_tags = None + self._last_res = None + self._needs_zeta_scale = [] + self._needs_delta_scale = [] + + def prepare(self, force=False): + """ + Prepare the MAD-NG TPSA matching environment. + This method sets up the MAD-NG environment for TPSA matching by + configuring the initial conditions, setting target locations, and quantities + based on the provided targets. + To achieve that, arrays and maps are created within MAD-NG to keep track of + the target locations, quantities and differential algebraic maps. + + Parameters + ---------- + force : bool, optional + If True, forces re-preparation even if already prepared. Default is False. + + Raises + ------ + ValueError + If the target quantity is not allowed with TPSA matching + or if start and end are provided without initial conditions. + """ + + if self._already_prepared and not force: + return + + # ------------------------------------------------------------ + # 1. COLLECT INITIAL CONDITIONS + # ------------------------------------------------------------ + + init = self.tw_kwargs.get('init', None) + + if init is None: + init = self.line.madng_twiss(**self.tw_kwargs) + self.tw_kwargs.update({'init': init}) + + assert isinstance(init, xt.TwissTable) + + if not hasattr(self.line.tracker, '_madng'): + self.line.build_madng_model() + + self.mng = self.line.tracker._madng + + self.twiss_flag = any(isinstance(tar, (xt.TargetRelPhaseAdvance)) for tar in self.targets) + + # Collect initial coordinates + coord_assign_str = self._initialize_coordinates_str(init) + + # Process targets + + targets_map_str, xs_ng_target_map = self._process_targets(init) + + + # Lua script assembly + observables = [loc for loc in self.optics_target_locations] + + param_list_str = _lua_list(self.vary_names) + observables_str = _lua_list(observables) + optics_qty_str = _lua_list(list(self.optics_target_quantities)) + + init_cond_str = self._beta_block_str(init) + + rmat_array_str = '' + if self.match_rmat: + rmat_array_str = f"{XSUITE_MADNG_ENV_NAME}.rmat_map_arr = table.new({self.sum_rmat_tar}, 0)\n" + + mng_init_str = r''' + ''' + XSUITE_MADNG_ENV_NAME + r''' = {} -- to avoid variable name clashes + local obs_flag = MAD.element.flags.observed + + local pts=''' + observables_str + r''' + + ''' + self.mng._sequence_name + r''':select(obs_flag, {list=pts}) + + local params = ''' + param_list_str + r''' + + local X0 = MAD.damap { + nv=6, -- number of variables + mo=2, -- max order of variables + np=#params, -- number of parameters + po=1, -- max order of parameters + pn=params, -- parameter names + } + + -- Converting to TPSA (mutating type) + for _, v in ipairs(params) do + MADX[v] = MADX[v] + X0[v] + end + + ''' + init_cond_str + r''' + + local map1 = MAD.gphys.bet2map(B0, X0:copy()) + + ''' + coord_assign_str + r''' + + -- Maps target locations to damaps + -- e.g. { 'BPM1' = damap1, 'BPM2' = damap2, ... } + ''' + XSUITE_MADNG_ENV_NAME + r'''.target_loc_map = table.new(0, ''' + str(len(self.optics_target_locations)) + r''') + + -- Maps rmat tags to rmat damaps + -- e.g. {rmat_damap1, rmat_damap2, ... } + ''' + rmat_array_str + r''' + + -- Array of targets with additional info (location, quantity, orbit/optical function) + -- e.g. { { loc = 'BPM1', qty = 'beta11', optfun = true }, { loc = 'BPM2', qty = 'x', orbit = 1 }, ... } + ''' + XSUITE_MADNG_ENV_NAME + r'''.targets_arr = table.new(''' + str(len(self.targets)) + r''', 0) + + -- List (Array) of target optics/orbit quantities (suitable for MAD-NG), + -- e.g. {'beta11', 'alfa11', ...} + ''' + XSUITE_MADNG_ENV_NAME + r'''.tar_optics_qtys = ''' + optics_qty_str + r''' + + -- Initial map for tracking/twiss + ''' + XSUITE_MADNG_ENV_NAME + r'''.init_X0_map = map1 + + -- Identity map + ''' + XSUITE_MADNG_ENV_NAME + r'''.empty_X0 = X0 + + -- Defining targets array + ''' + targets_map_str + r''' + + -- Mapping from xsuite quantity names to madng quantity names + ''' + xs_ng_target_map + r''' + ''' + + + self.mng.send(mng_init_str) + + self._already_prepared = True + + def _process_targets(self, init): + self.optics_target_locations = set() + self.optics_target_quantities = set() + self.rmat_start_end_list = [None] * self.sum_rmat_tar + self.rmat_tags = [] + + start = self.tw_kwargs.get('start', None) + end = self.tw_kwargs.get('end', None) + + targets_map_str = '' + xs_ng_target_map = XSUITE_MADNG_ENV_NAME + '.xs_ng_target_map = {}\n' + + for i, target in enumerate(self.targets): + if isinstance(target.tar, tuple): + self.match_opt = True + qty_orig = target.tar[0] + loc = target.tar[1] + + qty = qty_orig[:-3] if qty_orig.endswith('_ng') else XS_NG_MAP[qty_orig] + assert qty in TPSA_ALLOWED_TARGETS, f"Target quantity '{qty_orig}' not allowed with TPSA matching." + + self.optics_target_locations.add(loc) + + aux = '' + if qty in OPTFUN_QUANTITIES: + aux = 'optfun = true' + elif qty in ['x', 'px', 'y', 'py', 't', 'pt']: + aux = f'orbit = {['x', 'px', 'y', 'py', 't', 'pt'].index(qty) + 1}' + + if qty_orig == 'zeta': + self._needs_zeta_scale.append(i) + elif qty_orig == 'delta': + self._needs_delta_scale.append(i) + + # set string for quantity mapping + loc to save in madng + targets_map_str += f"{XSUITE_MADNG_ENV_NAME}.targets_arr[{i+1}] = {{ loc = '{loc}', qty = '{qty}', {aux} }}\n" + xs_ng_target_map += f"{XSUITE_MADNG_ENV_NAME}.xs_ng_target_map['{qty_orig}'] = '{qty}'\n" + + self.optics_target_quantities.add(qty_orig) + + + elif hasattr(target, "start") and hasattr(target, "end"): + if target.start != "__ele_start__": + loc_start = target.start + elif start is not None: + loc_start = start + else: + loc_start = init.name[0] + if target.end != "__ele_stop__": + loc_end = target.end + elif end is not None: + loc_end = end + else: + loc_end = init.name[-2] + + if isinstance(target, xt.TargetRelPhaseAdvance): + self.match_opt = True + qty_orig = target.var + qty = qty_orig[:-3] if qty_orig.endswith('_ng') else XS_NG_MAP[qty_orig] + + assert qty in TPSA_ALLOWED_TARGETS, f"Target quantity '{target.var}' not allowed with TPSA matching." + + self.optics_target_locations.add(loc_start) + self.optics_target_locations.add(loc_end) + self.optics_target_quantities.add(qty_orig) + + targets_map_str += f"{XSUITE_MADNG_ENV_NAME}.targets_arr[{i+1}] = {{ loc = '{loc_end}', qty = '{qty}', loc_start = '{loc_start}', optfun = true }}\n" + xs_ng_target_map += f"{XSUITE_MADNG_ENV_NAME}.xs_ng_target_map['{target.var}'] = '{qty}'\n" + + elif isinstance(target, xt.TargetRmatrixTerm): + self.match_rmat = True + qty = target.term + tag = target.tag.split('_')[0] + idx = int(tag) + + self.rmat_start_end_list[idx] = (loc_start, loc_end) + self.rmat_tags.append(target.tag) + + targets_map_str += f"{XSUITE_MADNG_ENV_NAME}.targets_arr[{i+1}] = {{ loc = '{loc_end}', qty = '{qty}', loc_start = '{loc_start}', rmat = true, tag = {tag} }}\n" + xs_ng_target_map += f"{XSUITE_MADNG_ENV_NAME}.xs_ng_target_map['{target.term}'] = '{qty}'\n" + + else: + raise NotImplementedError(f"Target of type {type(target)} not implemented for MAD-NG TPSA matching.") + + self.optics_target_locations = list(self.optics_target_locations) + + return targets_map_str, xs_ng_target_map + + def _beta_block_str(self, init): + madng_init_flag = "x_ng" in init.cols + quantity_appendix = "_ng" if "x_ng" in init.cols else "" + start = self.tw_kwargs.get('start', None) + start_loc = 0 if start is None else start + + init_cond_str = f"""local B0 = MAD.beta0 {{ + beta11 = {init['beta11'+quantity_appendix,start_loc] if madng_init_flag else init['betx',start_loc]}, + beta22 = {init['beta22'+quantity_appendix,start_loc] if madng_init_flag else init['bety',start_loc]}, + alfa11 = {init['alfa11'+quantity_appendix,start_loc] if madng_init_flag else init['alfx',start_loc]}, + alfa22 = {init['alfa22'+quantity_appendix,start_loc] if madng_init_flag else init['alfy',start_loc]}, + dx = {init['dx'+quantity_appendix,start_loc] if madng_init_flag else init['dx',start_loc]}, + dpx = {init['dpx'+quantity_appendix,start_loc] if madng_init_flag else init['dpx',start_loc]}, + dy = {init['dy'+quantity_appendix,start_loc] if madng_init_flag else init['dy',start_loc]}, + dpy = {init['dpy'+quantity_appendix,start_loc] if madng_init_flag else init['dpy',start_loc]}, + }}""" + + return init_cond_str + + def _initialize_coordinates_str(self, init): + madng_init_flag = "x_ng" in init.cols + quantity_appendix = "_ng" if madng_init_flag else "" + beta0 = self.line.particle_ref.beta0[0] + start = self.tw_kwargs.get('start', None) + start_loc = 0 if start is None else start + + init_coord = np.zeros(6) + init_coord[0] = init['x' + quantity_appendix, start_loc] + init_coord[1] = init['px' + quantity_appendix, start_loc] + init_coord[2] = init['y' + quantity_appendix, start_loc] + init_coord[3] = init['py' + quantity_appendix, start_loc] + + if madng_init_flag: + init_coord[4] = init['t_ng', start_loc] + init_coord[5] = init['pt_ng', start_loc] + else: + init_coord[4] = init['zeta', start_loc] / beta0 + init_coord[5] = init['ptau', start_loc] # ptau corresponds to pt + + # Build small Lua snippet: map1.x = val ... + coord_assign = " ".join( + f"map1.{p} = {v}" + for p, v in zip(['x','px','y','py','t','pt'], init_coord) + if abs(v) > 1e-12 + ) + return coord_assign + + def run(self): + """ + Execute the MAD-NG TPSA matching action. + This method performs either a Twiss or Track operation in MAD-NG + depending if quantities can be calculated through tracking or not. + It retrieves the results and constructs a TwissTable with the requested + target quantities at the specified target locations. + + Returns + ------- + xt.TwissTable + A TwissTable containing the results of the Twiss or Track operation + with the requested target quantities at the specified target locations. + """ + + if self._already_prepared is False: + self.prepare() + + start = self.tw_kwargs.get('start', None) + end = self.tw_kwargs.get('end', None) + + operation = "twiss" if self.twiss_flag else "track" + range_str = f"range='{start}/{end}', " if (start and end) else "" + + mng_track_str = ( + f"local trk, mflw = MAD.{operation}{{\n" + f" sequence={self.mng._sequence_name},\n" + f" X0={XSUITE_MADNG_ENV_NAME}.init_X0_map,\n" + f" savemap=true,\n" + f" observe=1,\n" + f" {range_str}\n" + f"}}\n" + f"{XSUITE_MADNG_ENV_NAME}.trk = trk\n" + ) + + self.mng.send(mng_track_str) + + loc_map_str = '' + loc_map_str = "\n".join(f"{XSUITE_MADNG_ENV_NAME}.target_loc_map['{loc}'] = {XSUITE_MADNG_ENV_NAME}.trk['{loc}'].__map" for loc in self.optics_target_locations) + + # Twiss + if self.twiss_flag: + mng_table_str = r''' + local trk = ''' + XSUITE_MADNG_ENV_NAME + r'''.trk + ''' + loc_map_str + r''' + py:send(trk) + ''' + + res = self.mng.send(mng_table_str).recv(XSUITE_MADNG_ENV_NAME + '.trk').to_df() + res = xt.TwissTable(res) + + # Add quantities which are not present yet with the name corresponding to the target quantity + for qty in self.optics_target_quantities: + if qty not in res.cols: + res[qty] = res[qty[:-3] if qty.endswith('_ng') else XS_NG_MAP[qty]] + + # Track + else: + mng_table_str = r''' + local trk = ''' + XSUITE_MADNG_ENV_NAME + r'''.trk + -- Add derived columns which are not present due to Track calculation + -- and use target names as defined from the user (Xsuite) + for _, tar in ipairs( ''' + XSUITE_MADNG_ENV_NAME + r'''.tar_optics_qtys ) do + if not trk[''' + XSUITE_MADNG_ENV_NAME + r'''.xs_ng_target_map[tar]] then + trk:addcol(tar, \ri -> MAD.gphys.optfun(trk[ri].__map, ''' + XSUITE_MADNG_ENV_NAME + r'''.xs_ng_target_map[tar] .. '_')) + end + end + + -- Save damaps + ''' + loc_map_str + r''' + py:send(trk) + ''' + + res_ng = self.mng.send(mng_table_str).recv(XSUITE_MADNG_ENV_NAME + '.trk') + res_tab = res_ng.to_df() + res = xt.TwissTable(res_tab) + + if 'zeta' in self.optics_target_quantities: + res._data.loc[:, 'zeta'] = res['t'] * self.line.particle_ref.beta0[0] + if 'delta' in self.optics_target_quantities: + res._data.loc[:, 'delta'] = ptau2delta(res['pt'], self.line.particle_ref.beta0[0]) + + if self.match_rmat: + res = self.handle_rmatrices(res) + + self._last_res = res + return res + + def handle_rmatrices(self, res): + rmat_str = '' + rmatrices = [] + for i in range(self.sum_rmat_tar): + start_rmat = self.rmat_start_end_list[i][0] + end_rmat = self.rmat_start_end_list[i][1] + if start_rmat == '__ele_start__': + start_rmat = self.tw_kwargs.get('start', None) + if end_rmat == '__ele_stop__': + end_rmat = self.tw_kwargs.get('end', None) + + range_str = '' + if start_rmat is not None and end_rmat is not None: + range_str = f"range = '{start_rmat}/{end_rmat}', " + rmat_str = ( + f"local trkid, mflwid = MAD.track{{\n" + f" sequence={self.mng._sequence_name},\n" + f" X0={XSUITE_MADNG_ENV_NAME}.empty_X0,\n" + f" savemap=true,\n" + f" observe=1,\n" + f" {range_str}\n" + f"}}\n" + f"{XSUITE_MADNG_ENV_NAME}.rmat_map_arr[{i}] = mflwid[1]\n" + f"local rmat = mflwid[1]:get1()\n" + f"py:send(rmat)\n" + ) + + rmat_res = self.mng.send(rmat_str).recv('rmat') + rmatrices.append(rmat_res) + + for tag in self.rmat_tags: + t0, term = tag.split("_") + ii = int(term[1]) - 1 + jj = int(term[2]) - 1 + res._data.attrs[tag] = rmatrices[int(t0)][ii, jj] + + return res + + def acquire_jacobian(self): + ''' + Acquire the Jacobian matrix for the TPSA matching targets and variables. + This method computes the Jacobian matrix for the specified targets and + variables using MAD-NG's TPSA capabilities. It constructs + the Jacobian matrix by evaluating the sensitivity of each target quantity + with respect to each variable using MAD-NG's optfun function (optical functions) + or by direct extraction from the TPSA (orbit). + + Returns + ------- + np.ndarray + A 2D NumPy array representing the Jacobian matrix, where each row + corresponds to a target and each column corresponds to a variable. + ''' + + tar_len_str = f"local tarlen = {len(self.targets)}\n" + vary_len_str = f"local varylen = {len(self.vary_names)}\n" + jac_decl_str = f"{XSUITE_MADNG_ENV_NAME}.jac = MAD.matrix(tarlen, varylen)\n" + + mng_str = tar_len_str + vary_len_str + jac_decl_str + r''' + -- Compute Jacobian + for i, target in ipairs( ''' + XSUITE_MADNG_ENV_NAME + r'''.targets_arr ) do + local map = nil + local nv = ''' + XSUITE_MADNG_ENV_NAME + r'''.init_X0_map:nv() + if target.optfun or target.orbit then + map = ''' + XSUITE_MADNG_ENV_NAME + r'''.target_loc_map[target.loc] + elseif target.rmat then + map = ''' + XSUITE_MADNG_ENV_NAME + r'''.rmat_map_arr[target.tag] + end + + local monom = MAD.monomial(nv + varylen) -- BUILD MONOMIAL + for j = 1, map.np(map), 1 do + -- Quantity which can be calculated with optfun + if target.optfun then + -- If loc_start (phase advance) is defined, we provide initial map + if target.loc_start then + local a0 = ''' + XSUITE_MADNG_ENV_NAME + r'''.target_loc_map[target.loc_start] + ''' + XSUITE_MADNG_ENV_NAME + r'''.jac[(i-1)*varylen + j] = MAD.gphys.optfun(map, target.qty .. "_", j, 1, a0) + else + ''' + XSUITE_MADNG_ENV_NAME + r'''.jac[(i-1)*varylen + j] = MAD.gphys.optfun(map, target.qty .. "_", j, 1) + end + + -- Orbit Quantity + elseif target.orbit then + monom[nv + j] = 1 + ''' + XSUITE_MADNG_ENV_NAME + r'''.jac[(i-1)*varylen + j] = map[target.orbit]:get(monom) + monom[nv + j] = 0 + + elseif target.rmat then + -- rmatrix terms are extracted directly from the damap + local ind_1 = tonumber(target.qty:sub(2,2)) + local ind_2 = tonumber(target.qty:sub(3,3)) + monom[nv + j] = 1 + monom[ind_2] = 1 + ''' + XSUITE_MADNG_ENV_NAME + r'''.jac[(i-1)*varylen + j] = map[ind_1]:get(monom) + monom[nv + j] = 0 + monom[ind_2] = 0 + end + end + end + + py:send(''' + XSUITE_MADNG_ENV_NAME + r'''.jac) + ''' + + self.mng.send(mng_str) + + jac = np.array(self.mng.recv()) + + for i in self._needs_zeta_scale: + jac[i, :] *= self.line.particle_ref.beta0[0] + for i in self._needs_delta_scale: + jac[i, :] *= dptau2ddelta(self._last_res['delta', self.targets[i].tar[1]], self.line.particle_ref.beta0[0]) + return jac + + def cleanup(self): + # Need to reconvert TPSAs to normal values + if self._already_prepared is True: + mng_str = '' + for var_name in self.vary_names: + mng_str += f"MADX['{var_name}'] = MADX['{var_name}']:get0()\n" + mng_str += f"{XSUITE_MADNG_ENV_NAME}.X0 = nil\n" + self.mng.send(mng_str) + self._already_prepared = False def line_to_madng(line, sequence_name='seq', temp_fname=None, keep_files=False, **kwargs): @@ -410,11 +966,16 @@ def line_to_madng(line, sequence_name='seq', temp_fname=None, keep_files=False, from pymadng import MAD + nocharge = str(kwargs.pop('nocharge', True)).lower() + mng = MAD(**kwargs) mng.send(f""" local mad_func = loadfile('{temp_fname}.mad', nil, MADX) assert(mad_func) mad_func() + MAD.option.nocharge = {nocharge} + MADX.option.rbarc = true + {XSUITE_MADNG_ENV_NAME} = {{}} -- to avoid variable name clashes """) mng._init_madx_data = madx_seq diff --git a/xtrack/match.py b/xtrack/match.py index 040983844..6afc0eb8b 100644 --- a/xtrack/match.py +++ b/xtrack/match.py @@ -36,9 +36,11 @@ 'dpy': 100., 'dy_ng' : 10., 'dpy_ng': 100., + 't_ng': 10., + 'pt_ng': 100., } -ALLOWED_TARGET_KWARGS= ['x', 'px', 'y', 'py', 'zeta', 'delta', 'pzata', 'ptau', +ALLOWED_TARGET_KWARGS= ['x', 'px', 'y', 'py', 'zeta', 'delta', 'pzeta', 'ptau', 'betx', 'bety', 'alfx', 'alfy', 'gamx', 'gamy', 'mux', 'muy', 'dx', 'dpx', 'dy', 'dpy', 'qx', 'qy', 'dqx', 'dqy', @@ -54,7 +56,8 @@ 'c_minus_re_0', 'c_minus_im_0', 'c_minus_re', 'c_minus_im', 'beta11_ng', 'beta22_ng', 'alfa11_ng', 'alfa22_ng', - 'dx_ng', 'dpx_ng'] + 'dx_ng', 'dpx_ng', 'dy_ng', 'dpy_ng', + 'x_ng', 'px_ng', 'y_ng', 'py_ng', 't_ng', 'pt_ng',] # Alternative transitions functions @@ -535,7 +538,6 @@ def __repr__(self): return f'TargetPhaseAdv({self.var}({self.end} - {self.start}), val={self.value}, tol={self.tol}, weight={self.weight})' def compute(self, tw): - if self.end == '__ele_stop__': mu_1 = tw[self.var, -1] else: @@ -548,7 +550,6 @@ def compute(self, tw): return mu_1 - mu_0 - class TargetRmatrixTerm(Target): def __init__(self, tar, value, start=None, end=None, tag='', **kwargs): @@ -607,6 +608,9 @@ def compute(self, tw): 'Only terms of the R-matrix in the form "r11", "r12", "r21", "r22", etc' ' are supported') + if hasattr(tw._data, 'attrs') and self.tag in tw._data.attrs: + return tw._data.attrs[self.tag] + if self.start is xt.START: self.start = tw.name[0] @@ -812,6 +816,52 @@ def run(self, allow_failure=True): out.line = self.line return out +class MeritFunctionLine(xd.MeritFunctionForMatch): + def __init__( + self, + merit_function_match, + use_tpsa=False, + ): + + self.vary = merit_function_match.vary + self.targets = merit_function_match.targets + self.actions = merit_function_match.actions + self.return_scalar = merit_function_match.return_scalar + self.call_counter = merit_function_match.call_counter + self.verbose = merit_function_match.verbose + self.tw_kwargs = merit_function_match.tw_kwargs + self.steps_for_jacobian = merit_function_match.steps_for_jacobian + self.found_point_within_tol = merit_function_match.found_point_within_tol + self.zero_if_met = merit_function_match.zero_if_met + self.show_call_counter = merit_function_match.show_call_counter + self.check_limits = merit_function_match.check_limits + self.use_tpsa = use_tpsa + + def get_jacobian(self, x=None, f0=None): + if self.use_tpsa: + return self.get_jacobian_tpsa() + else: + return super().get_jacobian(x, f0=f0) + + def get_jacobian_tpsa(self): + from .madng_interface import ActionTwissMadngTPSA + action = None + for a in self.actions: + if isinstance(a, ActionTwissMadngTPSA): + action = a + break + if action is None: + raise RuntimeError('No ActionTwissMadngTPSA found in actions for TPSA jacobian computation') + + + + jacobian = action.acquire_jacobian() + + for i, tar in enumerate(self.targets): + jacobian[i] *= tar.weight + + return jacobian + class OptimizeLine(xd.Optimize): def __init__(self, line, vary, targets, assert_within_tol=True, @@ -821,7 +871,7 @@ def __init__(self, line, vary, targets, assert_within_tol=True, n_steps_max=20, default_tol=None, solver=None, check_limits=True, action_twiss=None, action_twiss_ng=None, - name="", + use_tpsa=False, name="", **kwargs): if hasattr(targets, 'values'): # dict like @@ -830,7 +880,28 @@ def __init__(self, line, vary, targets, assert_within_tol=True, if not isinstance(targets, (list, tuple)): targets = [targets] + # Flatten targets and assign tags for TargetRMatrixTerms if multiple targets_flatten = [] + start_end_tuple_set = list() + rmat_index = 0 + if any(isinstance(tt, (TargetRmatrixTerm, TargetRmatrix)) for tt in targets): + for tt in targets: + if isinstance(tt, TargetRmatrixTerm): + if (tt.start, tt.end) in start_end_tuple_set: + rmat_index = start_end_tuple_set.index((tt.start, tt.end)) + else: + rmat_index = len(start_end_tuple_set) + start_end_tuple_set.append((tt.start, tt.end)) + tt.tag = f'{rmat_index}_{tt.term}' + elif isinstance(tt, TargetRmatrix): + if (tt.targets[0].start, tt.targets[0].end) in start_end_tuple_set: + rmat_index = start_end_tuple_set.index((tt.targets[0].start, tt.targets[0].end)) + else: + rmat_index = len(start_end_tuple_set) + start_end_tuple_set.append((tt.targets[0].start, tt.targets[0].end)) + for sub_tt in tt.targets: + sub_tt.tag = f'{rmat_index}_{sub_tt.term}' + for tt in targets: if isinstance(tt, xd.TargetList): for tt1 in tt.targets: @@ -840,16 +911,35 @@ def __init__(self, line, vary, targets, assert_within_tol=True, aux_vary = [] + # part of the `auxvar` experimental code + # if isinstance(tt.value, (GreaterThan, LessThan)): + # if tt.value.mode == 'auxvar': + # aux_vary.append(tt.value.gen_vary(aux_vary_container)) + # aux_vary_container[aux_vary[-1].name] = 0 + # val = tt.runeval() + # if val > 0: + # aux_vary_container[aux_vary[-1].name] = np.sqrt(val) + + if not isinstance(vary, (list, tuple)): + vary = [vary] + + vary = list(vary) + aux_vary + + vary_flatten = _flatten_vary(vary) + _complete_vary_with_info_from_line(vary_flatten, line) + for tt in targets_flatten: # Handle action if tt.action is None: - if (isinstance(tt.tar, tuple) and tt.tar[0].endswith('_ng')) or ( - isinstance(tt, TargetRelPhaseAdvance) and tt.var.endswith('_ng')): + if use_tpsa: if action_twiss_ng is None: - from .madng_interface import ActionTwissMadng - action_twiss_ng = ActionTwissMadng( - line, {}, **kwargs) + from .madng_interface import ActionTwissMadngTPSA + + action_twiss_ng = ActionTwissMadngTPSA( + line, [v.name for v in vary_flatten], targets_flatten, {}, + sum_rmat_tar=len(start_end_tuple_set), **kwargs + ) action_twiss_ng.prepare() tt.action = action_twiss_ng else: @@ -906,22 +996,6 @@ def __init__(self, line, vary, targets, assert_within_tol=True, else: tt.tol = default_tol - # part of the `auxvar` experimental code - # if isinstance(tt.value, (GreaterThan, LessThan)): - # if tt.value.mode == 'auxvar': - # aux_vary.append(tt.value.gen_vary(aux_vary_container)) - # aux_vary_container[aux_vary[-1].name] = 0 - # val = tt.runeval() - # if val > 0: - # aux_vary_container[aux_vary[-1].name] = np.sqrt(val) - - if not isinstance(vary, (list, tuple)): - vary = [vary] - - vary = list(vary) + aux_vary - - vary_flatten = _flatten_vary(vary) - _complete_vary_with_info_from_line(vary_flatten, line) xd.Optimize.__init__(self, vary=vary_flatten, targets=targets_flatten, solver=solver, @@ -931,9 +1005,12 @@ def __init__(self, line, vary, targets, assert_within_tol=True, restore_if_fail=restore_if_fail, check_limits=check_limits, name=name) + + _err = MeritFunctionLine(self._err, use_tpsa=use_tpsa) self.line = line self.action_twiss = action_twiss self.default_tol = default_tol + self._err = _err def clone(self, add_targets=None, add_vary=None, remove_targets=None, remove_vary=None, @@ -991,6 +1068,40 @@ def clone(self, add_targets=None, add_vary=None, def plot(self, *args, **kwargs): return self.action_twiss.run().plot(*args, **kwargs) + def step( + self, + n_steps=1, + take_best=True, + enable_target=None, + enable_vary=None, + enable_vary_name=None, + disable_target=None, + disable_vary=None, + disable_vary_name=None, + rcond=None, + sing_val_cutoff=None, + verbose=None, + broyden=False, + cleanup_madng_tpsa=False, + ): + super().step(n_steps, take_best, enable_target, enable_vary, enable_vary_name, disable_target, + disable_vary, disable_vary_name, rcond, sing_val_cutoff, verbose, broyden) + + if cleanup_madng_tpsa and self._err.use_tpsa: + for a in self.actions: + if hasattr(a, "cleanup"): + a.cleanup() + break + + def solve(self, n_steps=None, verbose=None, take_best=True, rcond=None, sing_val_cutoff=None, broyden=False, cleanup_madng_tpsa=True): + super().solve(n_steps, verbose, take_best, rcond, sing_val_cutoff, broyden) + + if cleanup_madng_tpsa and self._err.use_tpsa: + for a in self.actions: + if hasattr(a, "cleanup"): + a.cleanup() + break + def _flatten_vary(vary): vary_flatten = [] for vv in vary: diff --git a/xtrack/particles/particles.py b/xtrack/particles/particles.py index 5dfb1aca1..d64ef6e95 100644 --- a/xtrack/particles/particles.py +++ b/xtrack/particles/particles.py @@ -1277,7 +1277,7 @@ def gamma0(self): def gamma0(self, value): self.gamma0[:] = value - + def update_beta0(self, new_beta0): @@ -2305,4 +2305,42 @@ def _update_kwargs0_from_pdg_id(pdg_id, kwargs): q, _, _, _ = get_properties_from_pdg_id(pdg_id) kwargs['q0'] = q if mass0 is None: - kwargs['mass0'] = get_mass_from_pdg_id(pdg_id) \ No newline at end of file + kwargs['mass0'] = get_mass_from_pdg_id(pdg_id) + +def ptau2delta(ptau, beta0): + """Convert transverse momentum pt/p to relative momentum deviation dp/p. + + Parameters + ---------- + ptau : float + Transverse momentum relative to total momentum (pt/p, dimensionless). + beta0 : float + Particle relativistic beta (v/c). + + Returns + ------- + float + Relative momentum deviation (dp/p, dimensionless). + """ + + _beta0 = 1 / beta0 + return np.sqrt(1 + 2*ptau*_beta0 + ptau**2) - 1 + +def dptau2ddelta(ptau, beta0): + """Calculate derivative of relative momentum deviation dp/p with respect to pt. + + Parameters + ---------- + ptau : float + Transverse momentum relative to total momentum (pt/p, dimensionless). + beta0 : float + Particle relativistic beta (v/c). + + Returns + ------- + float + Derivative of relative momentum deviation (d(dp/p), dimensionless). + """ + + _beta0 = 1 / beta0 + return (_beta0 + ptau) / np.sqrt(1 + 2*ptau*_beta0 + ptau**2)