Skip to content

Commit 6c21457

Browse files
authored
Merge pull request #182 from zenml-io/fix-research-radar-config
2 parents 452474c + 47e2ad2 commit 6c21457

File tree

2 files changed

+20
-23
lines changed

2 files changed

+20
-23
lines changed

.typos.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,4 +50,4 @@ preprocessor = "preprocessor"
5050
logits = "logits"
5151

5252
[default]
53-
locale = "en-us"
53+
locale = "en-us"

research-radar/pipelines/classification.py

Lines changed: 19 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,7 @@
1515
# limitations under the License.
1616
#
1717

18-
"""
19-
Pipeline for article classification and dataset processing.
20-
"""
18+
"""Pipeline for article classification and dataset processing."""
2119

2220
from typing import Dict, Optional
2321

@@ -38,8 +36,7 @@
3836

3937
@pipeline(enable_cache=False)
4038
def classification_pipeline(config: Optional[Dict] = None):
41-
"""
42-
Pipeline for article classification and dataset processing.
39+
"""Pipeline for article classification and dataset processing.
4340
4441
Args:
4542
config: Pipeline configuration from base_config.yaml
@@ -53,41 +50,41 @@ def classification_pipeline(config: Optional[Dict] = None):
5350

5451
hf_token = get_hf_token()
5552

56-
pipeline_config = config.steps.classify
57-
classification_type = pipeline_config.classification_type
53+
classify_config = config["steps"]["classify"]
54+
classification_type = classify_config["classification_type"]
5855

5956
logger.log_classification_type(classification_type)
6057

6158
dataset_path = (
62-
config.datasets.unclassified
59+
config["datasets"]["unclassified"]
6360
if classification_type == "augmentation"
64-
else config.datasets.composite
61+
else config["datasets"]["composite"]
6562
)
6663

6764
articles = load_classification_dataset(dataset_path)
6865

6966
classifications = classify_articles(
7067
articles=articles,
7168
hf_token=hf_token,
72-
model_id=config.model_repo_ids.deepseek,
73-
inference_params=pipeline_config.inference_params,
69+
model_id=config["model_repo_ids"]["deepseek"],
70+
inference_params=classify_config["inference_params"],
7471
classification_type=classification_type,
75-
batch_config=pipeline_config.batch_processing,
76-
parallel_config=pipeline_config.parallel_processing,
77-
checkpoint_config=pipeline_config.checkpoint,
72+
batch_config=classify_config["batch_processing"],
73+
parallel_config=classify_config["parallel_processing"],
74+
checkpoint_config=classify_config["checkpoint"],
7875
)
7976

8077
results_path = save_classifications(
8178
classifications=classifications,
8279
classification_type=classification_type,
83-
model_id=config.model_repo_ids.deepseek,
84-
inference_params=pipeline_config.inference_params,
85-
batch_config=pipeline_config.batch_processing,
86-
checkpoint_config=pipeline_config.checkpoint,
80+
model_id=config["model_repo_ids"]["deepseek"],
81+
inference_params=classify_config["inference_params"],
82+
batch_config=classify_config["batch_processing"],
83+
checkpoint_config=classify_config["checkpoint"],
8784
)
8885

8986
if classification_type == "evaluation":
90-
base_dataset_path = config.datasets.composite
87+
base_dataset_path = config["datasets"]["composite"]
9188
calculate_and_save_metrics_from_json(
9289
results_path=str(results_path),
9390
base_dataset_path=base_dataset_path,
@@ -96,7 +93,7 @@ def classification_pipeline(config: Optional[Dict] = None):
9693
if classification_type == "augmentation":
9794
merge_classifications(
9895
results_path=results_path,
99-
training_dataset_path=config.datasets.composite,
100-
augmented_dataset_path=config.datasets.augmented,
101-
source_dataset_path=config.datasets.unclassified,
96+
training_dataset_path=config["datasets"]["composite"],
97+
augmented_dataset_path=config["datasets"]["augmented"],
98+
source_dataset_path=config["datasets"]["unclassified"],
10299
)

0 commit comments

Comments
 (0)