-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
211 lines (166 loc) · 7.82 KB
/
train.py
File metadata and controls
211 lines (166 loc) · 7.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import argparse
import json
import logging
import os
import time
from pathlib import Path
import pymupdf
from dotenv import load_dotenv
from sklearn.model_selection import RandomizedSearchCV
from swissgeol_doc_processing.utils.file_utils import read_params as swissgeol_read_params
from tqdm import tqdm
from xgboost import XGBClassifier
from src.models.feature_engineering import get_features
from src.models.treebased.basetrainer import TreeBasedTrainer
from src.models.treebased.model_explanation import explain_model
from src.page_classes import label2id
from src.utils.utility import get_pdf_files, read_params
# Module-level logger for this training script.
logger = logging.getLogger(__name__)
# Load variables from a .env file into the process environment. This must run before the
# MLFLOW_TRACKING lookup below so that values defined only in .env are taken into account.
load_dotenv()
# MLflow tracking is enabled only when the variable is exactly the string "True".
mlflow_tracking = os.getenv("MLFLOW_TRACKING") == "True"
if mlflow_tracking:
    # Imported conditionally so that mlflow is only imported when tracking is enabled.
    import mlflow
# Matching parameters passed to get_features() for every page during feature extraction.
# NOTE(review): relative path — presumably resolved against the current working directory; confirm the
# script is always launched from the repository root.
MATCHING_PARAMS_PATH = "config/local_matching_params.yml"
matching_params = read_params(MATCHING_PARAMS_PATH)
# Borehole matching parameters, loaded through swissgeol_doc_processing's own read_params helper
# (presumably resolved relative to that package's configuration — confirm).
borehole_matching_params = swissgeol_read_params("matching_params.yml")
class XGBoostTrainer(TreeBasedTrainer):
    """Trainer for XGBoost models.

    This class extends the TreeBasedTrainer to implement specific methods for training and evaluating
    XGBoost models using the provided configuration and data.
    """

    model_name = "xgboost_model"

    # Keyword arguments this trainer always sets itself. If the config repeated them, XGBClassifier(...)
    # would receive the same keyword twice and fail with a TypeError.
    _RESERVED_PARAMS = ("objective", "num_class")

    def _fixed_params(self) -> dict:
        """Return the multi-class settings shared by the trained model and the tuning estimator."""
        return {"objective": "multi:softprob", "num_class": self.num_labels}

    def prepare_model(self):
        """Prepares the XGBoost model for training.

        Hyperparameters are read from ``self.config["hyperparameters"]``; a missing or null entry means
        XGBoost defaults are used. ``objective`` and ``num_class`` are always set by this trainer. If the
        config repeats them, they are ignored with a warning instead of crashing on a duplicated keyword.
        """
        # `or {}` also covers a YAML `hyperparameters:` key with no value (None), which `**` cannot unpack.
        # Copy so that dropping reserved keys never mutates the shared config.
        hyperparams = dict(self.config.get("hyperparameters") or {})
        for key in self._RESERVED_PARAMS:
            if key in hyperparams:
                logger.warning(
                    "Ignoring configured hyperparameter '%s'; it is set by %s.", key, type(self).__name__
                )
                del hyperparams[key]
        self.model = XGBClassifier(**self._fixed_params(), **hyperparams)

    def tune_hyperparameters(
        self, param_dist: dict, n_iter: int = 20, scoring: str = "f1_micro", cv: int = 3, random_state: int = 42
    ) -> tuple[dict, float]:
        """Runs RandomizedSearchCV to tune hyperparameters for XGBoost.

        The search estimator starts from XGBoost defaults (plus the fixed multi-class settings and
        ``eval_metric="mlogloss"``), not from ``self.config["hyperparameters"]``. It is fitted on
        ``self.X_train`` / ``self.y_train`` and uses all CPU cores (``n_jobs=-1``).

        Args:
            param_dist (dict): Dictionary with parameters to search.
            n_iter (int): Number of parameter settings that are sampled.
            scoring (str): Scoring method to use for evaluation.
            cv (int): Number of folds in cross-validation.
            random_state (int): Seed passed to RandomizedSearchCV, which controls the parameter sampling.

        Returns:
            tuple[dict, float]: ``(best_params, best_score)``. These are the best hyperparameters found and
            their mean cross-validated score.
        """
        # Search estimator: fixed multi-class settings on top of XGBoost defaults.
        model = XGBClassifier(**self._fixed_params(), eval_metric="mlogloss")
        search = RandomizedSearchCV(
            estimator=model,
            param_distributions=param_dist,
            n_iter=n_iter,
            scoring=scoring,
            cv=cv,
            verbose=1,
            random_state=random_state,
            n_jobs=-1,
        )
        search.fit(self.X_train, self.y_train)
        return search.best_params_, search.best_score_
def load_data_and_labels(folder_path: Path, label_map: dict[tuple[str, int], int]):
    """Extract per-page features and their ground-truth labels from every PDF in a folder.

    Args:
        folder_path (Path): Folder whose PDF files are listed via ``get_pdf_files``.
        label_map (dict): Maps ``(file name, page number starting at 1)`` to a label ID.

    Returns:
        tuple: ``(features, labels)``. These are two parallel lists with one entry per PDF page.

    Raises:
        ValueError: If a page of any PDF has no entry in ``label_map``.
    """
    feature_rows = []
    page_labels = []
    pdf_paths = get_pdf_files(folder_path)
    for pdf_path in tqdm(pdf_paths, desc="Loading data ..."):
        # Labels are keyed by bare file name, not by the full path.
        pdf_name = os.path.basename(pdf_path)
        with pymupdf.Document(pdf_path) as document:
            for page_no, pdf_page in enumerate(document, start=1):
                lookup_key = (pdf_name, page_no)
                # Fail before the (expensive) feature extraction if the page is unlabeled.
                if lookup_key not in label_map:
                    raise ValueError(f"Missing label for file: {lookup_key}")
                page_features = get_features(pdf_page, page_no, matching_params, borehole_matching_params)
                page_labels.append(label_map[lookup_key])
                feature_rows.append(page_features)
    return feature_rows, page_labels
def build_filename_to_label_map(gt_json_path: Path) -> dict[tuple[str, int], int]:
    """Build a map from ``(filename, page)`` to class ID based on the ground-truth JSON.

    The JSON is read as a list of entries shaped like
    ``{"filename": ..., "pages": [{"page": ..., "classification": {label_name: value, ...}}, ...]}``.
    Every label whose value equals ``1`` is translated to its ID through ``label2id``.

    Args:
        gt_json_path (Path): Path to the ground-truth JSON file (decoded as UTF-8).

    Returns:
        dict: Mapping from ``(filename, page)`` to the page's label ID. Pages without a positive label
        are absent. If a page has several positive labels, the last one wins and a warning is logged.

    Raises:
        ValueError: If a positive label name is not present in ``label2id``.
    """
    # Decode explicitly as UTF-8. File names may contain non-ASCII characters (e.g. umlauts), which the
    # platform default encoding (such as cp1252 on Windows) would mis-decode or reject.
    with open(gt_json_path, encoding="utf-8") as f:
        gt_data = json.load(f)

    label_lookup: dict[tuple[str, int], int] = {}
    for entry in gt_data:
        filename = entry["filename"]
        for page_entry in entry["pages"]:
            key = (filename, page_entry["page"])
            for label_name, value in page_entry["classification"].items():
                if value != 1:
                    continue
                try:
                    label_id = label2id[label_name]
                except KeyError as err:
                    raise ValueError(f"Unknown label: {label_name}") from err
                previous = label_lookup.get(key)
                if previous is not None and previous != label_id:
                    # The classifier is single-label. Keep the previous last-wins behaviour, but make the
                    # ambiguous ground truth visible.
                    logging.getLogger(__name__).warning(
                        "Page %s has several positive labels; keeping '%s'.", key, label_name
                    )
                label_lookup[key] = label_id
    return label_lookup
def _apply_hyperparameter_tuning(trainer: "XGBoostTrainer", tuning_config: dict) -> None:
    """Run the randomized search, merge the best parameters into the trainer config and log them."""
    best_params, best_score = trainer.tune_hyperparameters(
        param_dist=tuning_config["param_grid"],
        n_iter=tuning_config.get("n_iter", 20),
        scoring=tuning_config.get("scoring", "f1_micro"),
        cv=tuning_config.get("cv", 3),
    )
    # The config may have no (or a null) "hyperparameters" entry. Create it instead of failing with
    # KeyError/AttributeError, so that prepare_model() picks up the tuned values.
    hyperparams = trainer.config.get("hyperparameters") or {}
    hyperparams.update(best_params)
    trainer.config["hyperparameters"] = hyperparams
    mlflow.log_params(best_params)
    mlflow.log_metric("best_cv_score", best_score)


def main(config_path: str, out_directory: str, tuning: bool = False):
    """Train, evaluate and log an XGBoost page classifier based on the provided configuration.

    Args:
        config_path (str): Path to the YAML configuration file.
        out_directory (str): Root directory for outputs. Each run writes into a timestamped
            sub-directory (``%Y%m%d-%H%M%S``).
        tuning (bool): Whether to perform hyperparameter tuning (settings from ``config["tuning"]``).
            Default is False.

    Raises:
        RuntimeError: If MLflow tracking is disabled (``MLFLOW_TRACKING`` is not ``"True"``).
        ValueError: If ``model_type`` is not ``"xgboost"``. Also propagated from the data-loading
            helpers for missing or unknown labels.
    """
    if not mlflow_tracking:
        raise RuntimeError("MLflow tracking is disabled. Set MLFLOW_TRACKING=True in .env to enable it.")
    mlflow.set_experiment("Classifier Training")

    config = read_params(config_path)
    trainer_name = config["model_type"]
    # Validate the model type up front, before the slow per-page PDF feature extraction.
    if trainer_name != "xgboost":
        raise ValueError(f"Unsupported trainer: '{trainer_name}'. Only 'xgboost' is supported.")

    train_folder = Path(config["train_folder_path"])
    val_folder = Path(config["val_folder_path"])
    ground_truth_path = Path(config["ground_truth_file_path"])
    model_out_directory = Path(out_directory) / time.strftime("%Y%m%d-%H%M%S")

    # Create dataset
    label_lookup = build_filename_to_label_map(ground_truth_path)
    X_train, y_train = load_data_and_labels(train_folder, label_lookup)
    X_val, y_val = load_data_and_labels(val_folder, label_lookup)

    trainer = XGBoostTrainer(config, model_out_directory)
    with mlflow.start_run(run_name=trainer_name):
        trainer.load_data(X_train, y_train, X_val, y_val)
        if tuning:
            _apply_hyperparameter_tuning(trainer, config["tuning"])
        # Uses the tuned hyperparameters when tuning ran, the configured ones otherwise.
        trainer.prepare_model()
        trainer.train()
        explain_model(trainer.model, trainer.X_train, trainer.id2label)
        trainer.save_model()

        # NOTE(review): predicts on the raw X_val handed to load_data. Confirm that load_data does not
        # transform the features it stores on the trainer.
        y_pred = trainer.model.predict(X_val)
        metrics = trainer.evaluate(y_pred)

        # Log to mlflow. `or {}` guards against a null "hyperparameters" entry in the config.
        mlflow.log_params(trainer.config.get("hyperparameters") or {})
        mlflow.log_metrics(metrics)
        mlflow.log_artifact(str(model_out_directory))
        if trainer.feature_names:
            mlflow.log_dict({"features": trainer.feature_names}, "features.json")
            trainer.plot_and_log_feature_importance()

        # Log confusion matrix and classification report
        trainer.plot_and_log_confusion_matrix(y_pred)
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and forward them to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--config-file-path", required=True, help="Path to YAML config file")
    cli.add_argument("--out-directory", required=True, help="Output directory root")
    cli.add_argument("--tuning", action="store_true", help="Enable hyperparameter tuning")
    cli_args = cli.parse_args()
    main(
        config_path=cli_args.config_file_path,
        out_directory=cli_args.out_directory,
        tuning=cli_args.tuning,
    )