Skip to content

Commit 5df4172

Browse files
Merge pull request #74 from stefanDeveloper/documentation/add-docstrings
Update missing docstrings or docstrings with too few information
2 parents e0b0f9c + 3dc130c commit 5df4172

File tree

8 files changed

+140
-12
lines changed

8 files changed

+140
-12
lines changed

src/base/clickhouse_kafka_sender.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ def __init__(self, table_name: str):
2727
)()
2828

2929
def insert(self, data: dict):
30-
"""Produces the insert operation to Kafka."""
30+
"""
31+
Produces the insert operation to Kafka.
32+
33+
Args:
34+
data (dict): content to write into the Kafka queue
35+
"""
3136
self.kafka_producer.produce(
3237
topic=f"clickhouse_{self.table_name}",
3338
data=self.data_schema.dumps(data),

src/base/data_classes/batch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88

99
@dataclass
1010
class Batch:
11+
"""
12+
Class definition of a batch, used to divide the log input into smaller amounts
13+
"""
14+
1115
batch_id: uuid.UUID = field(
1216
metadata={"marshmallow_field": marshmallow.fields.UUID()}
1317
)

src/base/kafka_handler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,15 @@ def consume_as_json(self) -> tuple[None | str, dict]:
284284
except Exception:
285285
raise ValueError("Unknown data format")
286286

287-
def _all_topics_created(self, topics):
287+
def _all_topics_created(self, topics) -> bool:
288+
"""
289+
Checks whether each topic in a list of topics was created. If not, retries a set number of times
290+
291+
Args:
292+
topics (list): List of topics to check
293+
Returns:
294+
bool
295+
"""
288296
number_of_retries_left = 30
289297
all_topics_created = False
290298
while not all_topics_created: # try for 15 seconds

src/detector/detector.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def calculate_entropy(s: str) -> float:
299299
return all_features.reshape(1, -1)
300300

301301
def detect(self) -> None: # pragma: no cover
302+
"""Method to detect malicious requests in the network flows"""
302303
logger.info("Start detecting malicious requests.")
303304
for message in self.messages:
304305
# TODO predict all messages
@@ -317,6 +318,7 @@ def detect(self) -> None: # pragma: no cover
317318
self.warnings.append(warning)
318319

319320
def send_warning(self) -> None:
321+
"""Dispatch warnings saved to the object's warning list"""
320322
logger.info("Store alert.")
321323
if len(self.warnings) > 0:
322324
overall_score = median(

src/inspector/inspector.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -336,10 +336,20 @@ def inspect(self):
336336
raise NotImplementedError(f"Mode {MODE} is not supported!")
337337

338338
def _inspect_multivariate(self, model: str):
339+
"""
340+
Method to inspect multivariate data for anomalies using a StreamAD Model
341+
Errors are counted in the time window and the model is fitted to retrieve scores.
342+
343+
Args:
344+
model (str): Model name (should be capable of handling multivariate data)
345+
346+
"""
339347
logger.debug(f"Load Model: {model['model']} from {model['module']}.")
340348
if not model["model"] in VALID_MULTIVARIATE_MODELS:
341-
logger.error(f"Model {model} is not a valid univariate model.")
342-
raise NotImplementedError(f"Model {model} is not a valid univariate model.")
349+
logger.error(f"Model {model} is not a valid multivariate model.")
350+
raise NotImplementedError(
351+
f"Model {model} is not a valid multivariate model."
352+
)
343353

344354
module = importlib.import_module(model["module"])
345355
module_model = getattr(module, model["model"])
@@ -367,11 +377,19 @@ def _inspect_multivariate(self, model: str):
367377
self.anomalies.append(0)
368378

369379
def _inspect_ensemble(self, models: str):
380+
"""
381+
Method to inspect data for anomalies using ensembles of two StreamAD models
382+
Errors are counted in the time window and the model is fitted to retrieve scores.
383+
384+
Args:
385+
models (str): Model names (should be valid ensemble models)
386+
387+
"""
370388
logger.debug(f"Load Model: {ENSEMBLE['model']} from {ENSEMBLE['module']}.")
371389
if not ENSEMBLE["model"] in VALID_ENSEMBLE_MODELS:
372-
logger.error(f"Model {ENSEMBLE} is not a valid univariate model.")
390+
logger.error(f"Model {ENSEMBLE} is not a valid ensemble model.")
373391
raise NotImplementedError(
374-
f"Model {ENSEMBLE} is not a valid univariate model."
392+
f"Model {ENSEMBLE} is not a valid ensemble model."
375393
)
376394

377395
module = importlib.import_module(ENSEMBLE["module"])
@@ -389,9 +407,9 @@ def _inspect_ensemble(self, models: str):
389407
for model in models:
390408
logger.debug(f"Load Model: {model['model']} from {model['module']}.")
391409
if not model["model"] in VALID_UNIVARIATE_MODELS:
392-
logger.error(f"Model {models} is not a valid univariate model.")
410+
logger.error(f"Model {models} is not a valid ensemble model.")
393411
raise NotImplementedError(
394-
f"Model {models} is not a valid univariate model."
412+
f"Model {models} is not a valid ensemble model."
395413
)
396414

397415
module = importlib.import_module(model["module"])
@@ -415,8 +433,7 @@ def _inspect_univariate(self, model: str):
415433
Errors are counted in the time window and the model is fitted to retrieve scores.
416434
417435
Args:
418-
model (BaseDetector): StreamAD model.
419-
model_args (dict): Arguments passed to the StreamAD model.
436+
model (str): StreamAD model name.
420437
"""
421438

422439
logger.debug(f"Load Model: {model['model']} from {model['module']}.")
@@ -445,6 +462,7 @@ def _inspect_univariate(self, model: str):
445462
self.anomalies.append(0)
446463

447464
def send_data(self):
465+
"""Pass the anomalous data for the detector unit for further processing"""
448466
total_anomalies = np.count_nonzero(
449467
np.greater_equal(np.array(self.anomalies), SCORE_THRESHOLD)
450468
)

src/logcollector/batch_handler.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,12 @@ def add_message(self, key: str, message: str) -> None:
406406
self._reset_timer()
407407

408408
def _send_all_batches(self, reset_timer: bool = True) -> None:
409+
"""
410+
Dispatch all batches to the Kafka queue
411+
412+
Args:
413+
reset_timer (bool): whether or not the timer should be reset
414+
"""
409415
number_of_keys = 0
410416
total_number_of_batch_messages = self.batch.get_message_count_for_batch()
411417
total_number_of_buffer_messages = self.batch.get_message_count_for_buffer()
@@ -438,6 +444,12 @@ def _send_all_batches(self, reset_timer: bool = True) -> None:
438444
)
439445

440446
def _send_batch_for_key(self, key: str) -> None:
447+
"""
448+
Send one batch based on the key
449+
450+
Args:
451+
key (str): Key to identify the batch
452+
"""
441453
try:
442454
data = self.batch.complete_batch(key)
443455
except ValueError as e:
@@ -447,6 +459,13 @@ def _send_batch_for_key(self, key: str) -> None:
447459
self._send_data_packet(key, data)
448460

449461
def _send_data_packet(self, key: str, data: dict) -> None:
462+
"""
463+
Sends a packet of a batch to the defined Kafka topic
464+
465+
Args:
466+
key (str): key to identify the batch
467+
data (dict): the batch data to send
468+
"""
450469
batch_schema = marshmallow_dataclass.class_schema(Batch)()
451470

452471
self.kafka_produce_handler.produce(
@@ -456,6 +475,7 @@ def _send_data_packet(self, key: str, data: dict) -> None:
456475
)
457476

458477
def _reset_timer(self) -> None:
478+
"""Restarts the internal timer of the object"""
459479
if self.timer:
460480
self.timer.cancel()
461481

src/monitoring/clickhouse_batch_sender.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ class Table:
3232
columns: dict[str, type]
3333

3434
def verify(self, data: dict[str, Any]):
35+
"""
36+
Verify if the data has the correct columns and types.
37+
38+
Args:
39+
data (dict): The values for each cell
40+
"""
3541
if len(data) != len(self.columns):
3642
raise ValueError(
3743
f"Wrong number of fields in data: Expected {len(self.columns)}, got {len(data)}"
@@ -182,7 +188,14 @@ def __del__(self):
182188
self.insert_all()
183189

184190
def add(self, table_name: str, data: dict[str, Any]):
185-
"""Adds the data to the batch for the table. Verifies the fields first."""
191+
"""
192+
Adds the data to the batch for the table. Verifies the fields first.
193+
194+
Args:
195+
table_name (str): Name of the table to add data to
196+
data (dict): The values for each cell in the table
197+
198+
"""
186199
self.tables.get(table_name).verify(data)
187200
self.batch.get(table_name).append(list(data.values()))
188201

@@ -193,7 +206,12 @@ def add(self, table_name: str, data: dict[str, Any]):
193206
self._start_timer()
194207

195208
def insert(self, table_name: str):
196-
"""Inserts the batch for the given table."""
209+
"""
210+
Inserts the batch for the given table.
211+
212+
Args:
213+
table_name (str): Name of the table to insert data to
214+
"""
197215
if self.batch[table_name]:
198216
with self.lock:
199217
self._client.insert(
@@ -217,6 +235,7 @@ def insert_all(self):
217235
self.timer = None
218236

219237
def _start_timer(self):
238+
"""Set the timer for batch processing of data insertion"""
220239
if self.timer:
221240
self.timer.cancel()
222241

src/train/model.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,16 @@ def __init__(
165165
super().__init__(processor, x_train, y_train)
166166

167167
def fdr_metric(self, preds: np.ndarray, dtrain: xgb.DMatrix) -> tuple[str, float]:
168+
"""
169+
Custom FDR metric to evaluate model performance based on False Discovery Rate.
170+
171+
Args:
172+
preds (np.ndarray): The predicted values.
173+
dtrain (xgb.DMatrix): The training data matrix.
174+
175+
Returns:
176+
tuple: A tuple containing the metric name ("fdr") and its value.
177+
"""
168178
# Get the true labels
169179
labels = dtrain.get_label()
170180

@@ -188,6 +198,15 @@ def fdr_metric(self, preds: np.ndarray, dtrain: xgb.DMatrix) -> tuple[str, float
188198
) # -1 is essentiell since XGBoost wants a scoring value (higher is better). However, FDR represents a loss function.
189199

190200
def objective(self, trial):
201+
"""
202+
Optimizes the XGBoost model hyperparameters using cross-validation.
203+
204+
Args:
205+
trial: A trial object from the optimization framework (e.g., Optuna).
206+
207+
Returns:
208+
float: The best FDR value after cross-validation.
209+
"""
191210
dtrain = xgb.DMatrix(self.x_train, label=self.y_train)
192211

193212
param = {
@@ -263,6 +282,13 @@ def predict(self, x):
263282
return self.clf.predict(x)
264283

265284
def train(self, trial, output_path):
285+
"""
286+
Trains the XGBoost model and saves the trained model to a file.
287+
288+
Args:
289+
trial: A trial object from the optimization framework.
290+
output_path (str): The directory path to save the trained model.
291+
"""
266292
logger.info("Number of estimators: {}".format(trial.user_attrs["n_estimators"]))
267293

268294
# dtrain = xgb.DMatrix(self.x_train, label=self.y_train)
@@ -300,6 +326,16 @@ def __init__(
300326

301327
# Define the custom FDR metric
302328
def fdr_metric(self, y_true: np.ndarray, y_pred: np.ndarray):
329+
"""
330+
Custom FDR metric to evaluate the performance of the Random Forest model.
331+
332+
Args:
333+
y_true (np.ndarray): The true labels.
334+
y_pred (np.ndarray): The predicted labels.
335+
336+
Returns:
337+
float: The False Discovery Rate (FDR).
338+
"""
303339
# False Positives (FP): cases where the model predicted 1 but the actual label is 0
304340
FP = np.sum((y_pred == 1) & (y_true == 0))
305341

@@ -315,6 +351,15 @@ def fdr_metric(self, y_true: np.ndarray, y_pred: np.ndarray):
315351
return fdr
316352

317353
def objective(self, trial):
354+
"""
355+
Optimizes the Random Forest model hyperparameters using cross-validation.
356+
357+
Args:
358+
trial: A trial object from the optimization framework (e.g., Optuna).
359+
360+
Returns:
361+
float: The best FDR value after cross-validation.
362+
"""
318363
# Define hyperparameters to optimize
319364
n_estimators = trial.suggest_int("n_estimators", 50, 300)
320365
max_depth = trial.suggest_int("max_depth", 2, 20)
@@ -359,6 +404,13 @@ def predict(self, x):
359404
return self.clf.predict(x)
360405

361406
def train(self, trial, output_path):
407+
"""
408+
Trains the Random Forest model and saves the trained model to a file.
409+
410+
Args:
411+
trial: A trial object from the optimization framework.
412+
output_path (str): The directory path to save the trained model.
413+
"""
362414
self.clf = RandomForestClassifier(**trial.params)
363415
self.clf.fit(self.x_train, self.y_train)
364416

0 commit comments

Comments
 (0)