diff --git a/run.py b/run.py index 21d6585..33131a7 100644 --- a/run.py +++ b/run.py @@ -10,7 +10,7 @@ import pprint FILEPATH= pathlib.Path(yaml_directory.__file__).parent.resolve() #FP= pathlib.Path(__file__).parent.resolve() -def run_simulation_server(name, monitor_overview, measurement_noise_level): +def run_simulation_server(name, monitor_overview, measurement_noise_level, number_of_particles): if name == "diag0": devices = load_relevant_controls( os.path.join( FILEPATH, "DIAG0.yaml") @@ -25,7 +25,7 @@ def run_simulation_server(name, monitor_overview, measurement_noise_level): raise ValueError(f"Unknown virtual accelerator name: {name}") PVDB = create_pvdb(devices) - va = get_virtual_accelerator(name, monitor_overview, measurement_noise_level) + va = get_virtual_accelerator(name, monitor_overview, measurement_noise_level, number_of_particles) server = SimServer(PVDB) driver = SimDriver(server=server, devices=devices, virtual_accelerator=va) @@ -54,7 +54,16 @@ def run_simulation_server(name, monitor_overview, measurement_noise_level): help="If provided, adds realistic noise to measurements. See `simulation_server.virtual_accelerator.utils.add_noise` for details.", ) + parser.add_argument( + "--number_of_particles", + type=int, + default=0, + help="If provided, limits the number of particles to this amount.", + ) + args = parser.parse_args() + + print("args.number_of_particles ", args.number_of_particles) run_simulation_server( - args.name, args.monitor_overview, args.measurement_noise_level - ) \ No newline at end of file + args.name, args.monitor_overview, args.measurement_noise_level, args.number_of_particles + ) diff --git a/simulation_server/factory.py b/simulation_server/factory.py index 0abb39a..ce86327 100644 --- a/simulation_server/factory.py +++ b/simulation_server/factory.py @@ -10,7 +10,7 @@ LCLS_LATTICE = pathlib.Path(os.environ.get("LCLS_LATTICE", "/sdf/group/ad/sw/scm/repos/optics/lcls-lattice/cheetah")) -def get_virtual_accelerator(name, monitor_overview=False, measurement_noise_level=None): +def get_virtual_accelerator(name, monitor_overview=False, measurement_noise_level=None, number_of_particles=0): """ Create an instance of VirtualAccelerator for a given beamline. @@ -26,6 +26,8 @@ def get_virtual_accelerator(name, monitor_overview=False, measurement_noise_leve measurement_noise_level: float, optional If provided, adds realistic noise to measurements. See `simulation_server.virtual_accelerator.utils.add_noise` for details. + num_particles: int, optional + If provided, limits the number of particles to this amount. Returns ------- @@ -60,6 +62,10 @@ def get_virtual_accelerator(name, monitor_overview=False, measurement_noise_leve mapping_file = os.path.join(FILEPATH, "mappings", "lcls_elements.csv") lattice_file = os.path.join(LCLS_LATTICE,"nc_hxr.json") + if number_of_particles > 0: + incoming_beam.particles = incoming_beam.particles[:number_of_particles] + incoming_beam.survival_probabilities=incoming_beam.survival_probabilities[:number_of_particles] + return VirtualAccelerator( lattice_file=lattice_file, initial_beam_distribution=incoming_beam, diff --git a/start.sh b/start.sh index 59af758..fc3082d 100755 --- a/start.sh +++ b/start.sh @@ -12,11 +12,12 @@ else export LCLS_LATTICE=$1 fi - + NAME="${2:-diag0}" OVERVIEW="${3:-False}" NOISE="${4:-0.0}" +NUM_PARTICLES="${5:-0}" # Start the server echo "Starting server..." -python3 run.py --name $NAME --monitor_overview $OVERVIEW --measurement_noise_level $NOISE \ No newline at end of file +python3 run.py --name $NAME --monitor_overview $OVERVIEW --measurement_noise_level $NOISE --number_of_particles $NUM_PARTICLES