diff --git a/.typos.toml b/.typos.toml
index 75fc8b0c..2fa346b5 100644
--- a/.typos.toml
+++ b/.typos.toml
@@ -50,4 +50,4 @@ preprocessor = "preprocessor"
 logits = "logits"
 
 [default]
-locale = "en-us"
+locale = "en-us"
\ No newline at end of file
diff --git a/research-radar/pipelines/classification.py b/research-radar/pipelines/classification.py
index 743284c5..5366e302 100644
--- a/research-radar/pipelines/classification.py
+++ b/research-radar/pipelines/classification.py
@@ -15,9 +15,7 @@
 # limitations under the License.
 #
 
-"""
-Pipeline for article classification and dataset processing.
-"""
+"""Pipeline for article classification and dataset processing."""
 
 from typing import Dict, Optional
 
@@ -38,8 +36,7 @@
 
 @pipeline(enable_cache=False)
 def classification_pipeline(config: Optional[Dict] = None):
-    """
-    Pipeline for article classification and dataset processing.
+    """Pipeline for article classification and dataset processing.
 
     Args:
         config: Pipeline configuration from base_config.yaml
@@ -53,15 +50,15 @@ def classification_pipeline(config: Optional[Dict] = None):
 
     hf_token = get_hf_token()
 
-    pipeline_config = config.steps.classify
-    classification_type = pipeline_config.classification_type
+    classify_config = config["steps"]["classify"]
+    classification_type = classify_config["classification_type"]
 
     logger.log_classification_type(classification_type)
 
     dataset_path = (
-        config.datasets.unclassified
+        config["datasets"]["unclassified"]
         if classification_type == "augmentation"
-        else config.datasets.composite
+        else config["datasets"]["composite"]
     )
 
     articles = load_classification_dataset(dataset_path)
@@ -69,25 +66,25 @@
     classifications = classify_articles(
         articles=articles,
         hf_token=hf_token,
-        model_id=config.model_repo_ids.deepseek,
-        inference_params=pipeline_config.inference_params,
+        model_id=config["model_repo_ids"]["deepseek"],
+        inference_params=classify_config["inference_params"],
         classification_type=classification_type,
-        batch_config=pipeline_config.batch_processing,
-        parallel_config=pipeline_config.parallel_processing,
-        checkpoint_config=pipeline_config.checkpoint,
+        batch_config=classify_config["batch_processing"],
+        parallel_config=classify_config["parallel_processing"],
+        checkpoint_config=classify_config["checkpoint"],
     )
 
     results_path = save_classifications(
         classifications=classifications,
         classification_type=classification_type,
-        model_id=config.model_repo_ids.deepseek,
-        inference_params=pipeline_config.inference_params,
-        batch_config=pipeline_config.batch_processing,
-        checkpoint_config=pipeline_config.checkpoint,
+        model_id=config["model_repo_ids"]["deepseek"],
+        inference_params=classify_config["inference_params"],
+        batch_config=classify_config["batch_processing"],
+        checkpoint_config=classify_config["checkpoint"],
    )
 
     if classification_type == "evaluation":
-        base_dataset_path = config.datasets.composite
+        base_dataset_path = config["datasets"]["composite"]
         calculate_and_save_metrics_from_json(
             results_path=str(results_path),
             base_dataset_path=base_dataset_path,
@@ -96,7 +93,7 @@
     if classification_type == "augmentation":
         merge_classifications(
             results_path=results_path,
-            training_dataset_path=config.datasets.composite,
-            augmented_dataset_path=config.datasets.augmented,
-            source_dataset_path=config.datasets.unclassified,
+            training_dataset_path=config["datasets"]["composite"],
+            augmented_dataset_path=config["datasets"]["augmented"],
+            source_dataset_path=config["datasets"]["unclassified"],
         )
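
Note on the access-style change above: rewriting `config.steps.classify` as `config["steps"]["classify"]` assumes `config` arrives as a plain nested dict (e.g., the result of `yaml.safe_load` on `base_config.yaml`) rather than an attribute-access wrapper such as a namespace or OmegaConf object. A minimal sketch of the difference, using hypothetical values that mirror the keys in the diff:

    # Sketch only: this YAML is illustrative, not the repo's base_config.yaml.
    import yaml

    raw = """
    steps:
      classify:
        classification_type: augmentation
    datasets:
      unclassified: data/unclassified.json
      composite: data/composite.json
    """

    config = yaml.safe_load(raw)  # -> nested plain dicts

    # Dict-style access works on the parsed config:
    classify_config = config["steps"]["classify"]
    print(classify_config["classification_type"])  # augmentation

    # Attribute-style access does not:
    # config.steps  # AttributeError: 'dict' object has no attribute 'steps'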