diff --git a/setup.cfg b/setup.cfg index 1f0b48c..3de096b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,8 +28,8 @@ install_requires = click>=8.0 pandas>=1.1.0 rxn-chem-utils>=1.3.0 - rxn-onmt-models>=1.0.0 - rxn-onmt-utils>=1.0.0 + rxn-onmt-models@git+https://github.com/rxn4chemistry/rxn-onmt-models.git@d14ee743fe0b4e1881613113f33304b2b91887c8 + rxn-onmt-utils@git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@100ed8c721fcc026c603c11998a8478b13630ce8 rxn-utils>=1.1.9 [options.packages.find] diff --git a/src/rxn/metrics/run_metrics.py b/src/rxn/metrics/run_metrics.py index 769bce6..abea38e 100644 --- a/src/rxn/metrics/run_metrics.py +++ b/src/rxn/metrics/run_metrics.py @@ -68,6 +68,8 @@ def run_model_for_metrics( batch_size: int, gpu: bool, initialize_logger: bool = False, + as_external_command: bool = False, + **kwargs, ) -> None: ensure_directory_exists_and_is_empty(output_dir) files = get_metrics_files(task, output_dir) @@ -88,6 +90,8 @@ def run_model_for_metrics( beam_size=beam_size, batch_size=batch_size, gpu=gpu, + as_external_command=as_external_command, + **kwargs, ) canonicalize_file( diff --git a/src/rxn/metrics/scripts/prepare_forward_metrics.py b/src/rxn/metrics/scripts/prepare_forward_metrics.py index 9f63d1a..8dc3229 100644 --- a/src/rxn/metrics/scripts/prepare_forward_metrics.py +++ b/src/rxn/metrics/scripts/prepare_forward_metrics.py @@ -38,6 +38,10 @@ @click.option( "--no_metrics", is_flag=True, help="If given, the metrics will not be computed." 
) +@click.option( + "--as_external_command", type=bool, default=False, help="Run translation as external ONMT command" +) +@click.argument('extra_options', nargs=-1, type=click.UNPROCESSED) def main( precursors_file: Path, products_file: Path, @@ -47,10 +51,15 @@ def main( n_best: int, gpu: bool, no_metrics: bool, + as_external_command: bool, + extra_options: tuple, ) -> None: """Starting from the ground truth files and forward model, generate the translation files needed for the metrics, and calculate the default metrics.""" + # Parse KEY=VALUE extra options; split on the first '=' so values may contain '=' + kwargs = {key: value for key, value in (opt.split('=', 1) for opt in extra_options)} + run_model_for_metrics( task="forward", model_path=forward_model, @@ -62,6 +71,8 @@ def main( batch_size=batch_size, gpu=gpu, initialize_logger=True, + as_external_command=as_external_command, + **kwargs, ) if not no_metrics: diff --git a/src/rxn/metrics/scripts/prepare_retro_metrics.py b/src/rxn/metrics/scripts/prepare_retro_metrics.py index 6652abd..fe0c67b 100644 --- a/src/rxn/metrics/scripts/prepare_retro_metrics.py +++ b/src/rxn/metrics/scripts/prepare_retro_metrics.py @@ -93,6 +93,10 @@ "only if the true reactant accuracy is activated." ), ) +@click.option( + "--as_external_command", type=bool, default=False, help="Run translation as external ONMT command" +) +@click.argument('extra_options', nargs=-1, type=click.UNPROCESSED) def main( precursors_file: Path, products_file: Path, @@ -108,10 +112,16 @@ def main( class_tokens: Optional[int], with_true_reactant_accuracy: bool, rxnmapper_batch_size: int, + as_external_command: bool, + extra_options: tuple, ) -> None: """Starting from the ground truth files and two models (retro, forward), generate the translation files needed for the metrics, and calculate the default metrics. 
""" + + # Parse KEY=VALUE extra options; split on the first '=' so values may contain '=' + kwargs = {key: value for key, value in (opt.split('=', 1) for opt in extra_options)} + true_reactant_environment_check(with_true_reactant_accuracy) ensure_directory_exists_and_is_empty(output_dir) @@ -142,6 +152,8 @@ def main( beam_size=beam_size, batch_size=batch_size, gpu=gpu, + as_external_command=as_external_command, + **kwargs, ) canonicalize_file( @@ -161,6 +173,8 @@ def main( beam_size=10, batch_size=batch_size, gpu=gpu, + as_external_command=as_external_command, + **kwargs, ) canonicalize_file(