# limitations under the License.
#

"""Pipeline for article classification and dataset processing."""

from typing import Dict, Optional

3836
@pipeline(enable_cache=False)
def classification_pipeline(config: Optional[Dict] = None):
    """Pipeline for article classification and dataset processing.

    Args:
        config: Pipeline configuration from base_config.yaml
    """
    # NOTE(review): this block was recovered from a diff view; ~7 original
    # lines between the docstring's Args section and get_hf_token() (the
    # docstring tail and possibly config defaulting) are not visible in the
    # hunk — restore them from the full file before relying on this text.

    hf_token = get_hf_token()

    # Config values are plain dict lookups (this commit replaced attribute
    # access, e.g. config.steps.classify, with subscripting).
    classify_config = config["steps"]["classify"]
    classification_type = classify_config["classification_type"]

    logger.log_classification_type(classification_type)

    # Augmentation classifies the unclassified pool; any other mode
    # (e.g. evaluation) runs against the composite dataset.
    dataset_path = (
        config["datasets"]["unclassified"]
        if classification_type == "augmentation"
        else config["datasets"]["composite"]
    )

    articles = load_classification_dataset(dataset_path)

    classifications = classify_articles(
        articles=articles,
        hf_token=hf_token,
        model_id=config["model_repo_ids"]["deepseek"],
        inference_params=classify_config["inference_params"],
        classification_type=classification_type,
        batch_config=classify_config["batch_processing"],
        parallel_config=classify_config["parallel_processing"],
        checkpoint_config=classify_config["checkpoint"],
    )

    results_path = save_classifications(
        classifications=classifications,
        classification_type=classification_type,
        model_id=config["model_repo_ids"]["deepseek"],
        inference_params=classify_config["inference_params"],
        batch_config=classify_config["batch_processing"],
        checkpoint_config=classify_config["checkpoint"],
    )

    if classification_type == "evaluation":
        base_dataset_path = config["datasets"]["composite"]
        calculate_and_save_metrics_from_json(
            results_path=str(results_path),
            base_dataset_path=base_dataset_path,
        )

    if classification_type == "augmentation":
        merge_classifications(
            results_path=results_path,
            training_dataset_path=config["datasets"]["composite"],
            augmented_dataset_path=config["datasets"]["augmented"],
            source_dataset_path=config["datasets"]["unclassified"],
        )