@@ -165,6 +165,16 @@ def __init__(
165165 super ().__init__ (processor , x_train , y_train )
166166
167167 def fdr_metric (self , preds : np .ndarray , dtrain : xgb .DMatrix ) -> tuple [str , float ]:
168+ """
169+ Custom FDR metric to evaluate model performance based on False Discovery Rate.
170+
171+ Args:
172+ preds (np.ndarray): The predicted values.
173+ dtrain (xgb.DMatrix): The training data matrix.
174+
175+ Returns:
176+ tuple: A tuple containing the metric name ("fdr") and its value.
177+ """
168178 # Get the true labels
169179 labels = dtrain .get_label ()
170180
@@ -188,6 +198,15 @@ def fdr_metric(self, preds: np.ndarray, dtrain: xgb.DMatrix) -> tuple[str, float
188198 ) # -1 is essentiell since XGBoost wants a scoring value (higher is better). However, FDR represents a loss function.
189199
190200 def objective (self , trial ):
201+ """
202+ Optimizes the XGBoost model hyperparameters using cross-validation.
203+
204+ Args:
205+ trial: A trial object from the optimization framework (e.g., Optuna).
206+
207+ Returns:
208+ float: The best FDR value after cross-validation.
209+ """
191210 dtrain = xgb .DMatrix (self .x_train , label = self .y_train )
192211
193212 param = {
@@ -263,6 +282,13 @@ def predict(self, x):
263282 return self .clf .predict (x )
264283
265284 def train (self , trial , output_path ):
285+ """
286+ Trains the XGBoost model and saves the trained model to a file.
287+
288+ Args:
289+ trial: A trial object from the optimization framework.
290+ output_path (str): The directory path to save the trained model.
291+ """
266292 logger .info ("Number of estimators: {}" .format (trial .user_attrs ["n_estimators" ]))
267293
268294 # dtrain = xgb.DMatrix(self.x_train, label=self.y_train)
@@ -300,6 +326,16 @@ def __init__(
300326
301327 # Define the custom FDR metric
302328 def fdr_metric (self , y_true : np .ndarray , y_pred : np .ndarray ):
329+ """
330+ Custom FDR metric to evaluate the performance of the Random Forest model.
331+
332+ Args:
333+ y_true (np.ndarray): The true labels.
334+ y_pred (np.ndarray): The predicted labels.
335+
336+ Returns:
337+ float: The False Discovery Rate (FDR).
338+ """
303339 # False Positives (FP): cases where the model predicted 1 but the actual label is 0
304340 FP = np .sum ((y_pred == 1 ) & (y_true == 0 ))
305341
@@ -315,6 +351,15 @@ def fdr_metric(self, y_true: np.ndarray, y_pred: np.ndarray):
315351 return fdr
316352
317353 def objective (self , trial ):
354+ """
355+ Optimizes the Random Forest model hyperparameters using cross-validation.
356+
357+ Args:
358+ trial: A trial object from the optimization framework (e.g., Optuna).
359+
360+ Returns:
361+ float: The best FDR value after cross-validation.
362+ """
318363 # Define hyperparameters to optimize
319364 n_estimators = trial .suggest_int ("n_estimators" , 50 , 300 )
320365 max_depth = trial .suggest_int ("max_depth" , 2 , 20 )
@@ -359,6 +404,13 @@ def predict(self, x):
359404 return self .clf .predict (x )
360405
361406 def train (self , trial , output_path ):
407+ """
408+ Trains the Random Forest model and saves the trained model to a file.
409+
410+ Args:
411+ trial: A trial object from the optimization framework.
412+ output_path (str): The directory path to save the trained model.
413+ """
362414 self .clf = RandomForestClassifier (** trial .params )
363415 self .clf .fit (self .x_train , self .y_train )
364416
0 commit comments