Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ install_requires =
click>=8.0
pandas>=1.1.0
rxn-chem-utils>=1.3.0
rxn-onmt-models>=1.0.0
rxn-onmt-utils>=1.0.0
rxn-onmt-models@git+https://github.com/rxn4chemistry/rxn-onmt-models.git@d14ee743fe0b4e1881613113f33304b2b91887c8
rxn-onmt-utils@git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@100ed8c721fcc026c603c11998a8478b13630ce8
rxn-utils>=1.1.9

[options.packages.find]
Expand Down
4 changes: 4 additions & 0 deletions src/rxn/metrics/run_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def run_model_for_metrics(
batch_size: int,
gpu: bool,
initialize_logger: bool = False,
as_external_command: bool = False,
**kwargs,
) -> None:
ensure_directory_exists_and_is_empty(output_dir)
files = get_metrics_files(task, output_dir)
Expand All @@ -88,6 +90,8 @@ def run_model_for_metrics(
beam_size=beam_size,
batch_size=batch_size,
gpu=gpu,
as_external_command=as_external_command,
**kwargs,
)

canonicalize_file(
Expand Down
11 changes: 11 additions & 0 deletions src/rxn/metrics/scripts/prepare_forward_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
@click.option(
"--no_metrics", is_flag=True, help="If given, the metrics will not be computed."
)
@click.option(
"--as_external_command", type=bool, default=False, help="Run translation as external ONMT command"
)
@click.argument('extra_options', nargs=-1, type=click.UNPROCESSED)
def main(
precursors_file: Path,
products_file: Path,
Expand All @@ -47,10 +51,15 @@ def main(
n_best: int,
gpu: bool,
no_metrics: bool,
as_external_command: bool,
extra_options: list,
) -> None:
"""Starting from the ground truth files and forward model, generate the
translation files needed for the metrics, and calculate the default metrics."""

# Convert extra_options into a dictionary
kwargs = {key: value for key, value in (opt.split('=') for opt in extra_options)}

run_model_for_metrics(
task="forward",
model_path=forward_model,
Expand All @@ -62,6 +71,8 @@ def main(
batch_size=batch_size,
gpu=gpu,
initialize_logger=True,
as_external_command=as_external_command,
**kwargs,
)

if not no_metrics:
Expand Down
14 changes: 14 additions & 0 deletions src/rxn/metrics/scripts/prepare_retro_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@
"only if the true reactant accuracy is activated."
),
)
@click.option(
"--as_external_command", type=bool, default=False, help="Run translation as external ONMT command"
)
@click.argument('extra_options', nargs=-1, type=click.UNPROCESSED)
def main(
precursors_file: Path,
products_file: Path,
Expand All @@ -108,10 +112,16 @@ def main(
class_tokens: Optional[int],
with_true_reactant_accuracy: bool,
rxnmapper_batch_size: int,
as_external_command: bool,
extra_options: list,
) -> None:
"""Starting from the ground truth files and two models (retro, forward),
generate the translation files needed for the metrics, and calculate the default metrics.
"""

# Convert extra_options into a dictionary
kwargs = {key: value for key, value in (opt.split('=') for opt in extra_options)}

true_reactant_environment_check(with_true_reactant_accuracy)

ensure_directory_exists_and_is_empty(output_dir)
Expand Down Expand Up @@ -142,6 +152,8 @@ def main(
beam_size=beam_size,
batch_size=batch_size,
gpu=gpu,
as_external_command=as_external_command,
**kwargs,
)

canonicalize_file(
Expand All @@ -161,6 +173,8 @@ def main(
beam_size=10,
batch_size=batch_size,
gpu=gpu,
as_external_command=as_external_command,
**kwargs,
)

canonicalize_file(
Expand Down
Loading