Skip to content

Commit e59e1e0

Browse files
Julien Roussel authored and committed
Merge branch 'dev' of https://github.com/Quantmetry/qolmat into dev
2 parents 6390bc7 + 6c3eeb6 commit e59e1e0

File tree

10 files changed: 349 additions and 84 deletions

.coveragerc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[run]
2+
omit = qolmat/_version.py

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ examples/*.ipynb
5959
examples/figures/*
6060
examples/data/*
6161
examples/local
62-
62+
data/data_local/*
6363

6464
# VSCode
6565
.vscode

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[pytest]
2+
addopts = --cov=qolmat

qolmat/benchmark/comparator.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_errors(
5252
df_origin: pd.DataFrame,
5353
df_imputed: pd.DataFrame,
5454
df_mask: pd.DataFrame,
55-
) -> pd.Series:
55+
) -> pd.DataFrame:
5656
"""Functions evaluating the reconstruction's quality
5757
5858
Parameters
@@ -64,15 +64,15 @@ def get_errors(
6464
6565
Returns
6666
-------
67-
dictionary
68-
dictionay of results obtained via different metrics
67+
pd.DataFrame
68+
DataFrame of results obtained via different metrics
6969
"""
7070
dict_errors = {}
7171
for name_metric in self.metrics:
7272
fun_metric = metrics.get_metric(name_metric)
7373
dict_errors[name_metric] = fun_metric(df_origin, df_imputed, df_mask)
74-
errors = pd.concat(dict_errors.values(), keys=dict_errors.keys())
75-
return errors
74+
df_errors = pd.concat(dict_errors.values(), keys=dict_errors.keys())
75+
return df_errors
7676

7777
def evaluate_errors_sample(
7878
self,
@@ -96,8 +96,8 @@ def evaluate_errors_sample(
9696
9797
Returns
9898
-------
99-
pd.DataFrame
100-
DataFrame with the errors for each metric (in column) and at each fold (in index)
99+
pd.Series
100+
Series with the errors for each metric and each variable
101101
"""
102102
list_errors = []
103103
df_origin = df[self.selected_columns].copy()
@@ -115,8 +115,12 @@ def evaluate_errors_sample(
115115
)
116116
df_imputed = imputer_opti.fit_transform(df_corrupted)
117117
subset = self.generator_holes.subset
118-
errors = self.get_errors(df_origin[subset], df_imputed[subset], df_mask[subset])
119-
list_errors.append(errors)
118+
if subset is None:
119+
raise ValueError(
120+
"HoleGenerator `subset` should be overwritten in split but it is none!"
121+
)
122+
df_errors = self.get_errors(df_origin[subset], df_imputed[subset], df_mask[subset])
123+
list_errors.append(df_errors)
120124
df_errors = pd.DataFrame(list_errors)
121125
errors_mean = df_errors.mean(axis=0)
122126

@@ -136,7 +140,8 @@ def compare(
136140
Returns
137141
-------
138142
pd.DataFrame
139-
dataframe with imputation
143+
Dataframe with the metrics results, imputers are in columns and indices represent
144+
metrics and variables.
140145
"""
141146

142147
dict_errors = {}

qolmat/benchmark/metrics.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Callable, Dict, List, Optional
2+
from typing import Callable, Dict, List
33

44
import numpy as np
55
import pandas as pd
@@ -1030,7 +1030,9 @@ def pattern_based_weighted_mean_metric(
10301030
return pd.Series(sum([s * w for s, w in zip(scores, weights)]), index=["All"])
10311031

10321032

1033-
def get_metric(name: str) -> Callable:
1033+
def get_metric(
1034+
name: str,
1035+
) -> Callable[[pd.DataFrame, pd.DataFrame, pd.DataFrame], pd.Series]:
10341036
dict_metrics: Dict[str, Callable] = {
10351037
"mse": mean_squared_error,
10361038
"rmse": root_mean_squared_error,

qolmat/utils/data.py

Lines changed: 74 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import zipfile
44
from datetime import datetime
55
from math import pi
6-
from typing import List
6+
from typing import List, Tuple, Union
77
from urllib import request
88

99
import numpy as np
@@ -36,6 +36,24 @@ def read_csv_local(data_file_name: str, **kwargs) -> pd.DataFrame:
3636
def download_data_from_zip(
3737
zipname: str, urllink: str, datapath: str = "data/"
3838
) -> List[pd.DataFrame]:
39+
"""
40+
Downloads and extracts ZIP files from a URL, then loads DataFrames from CSV files.
41+
42+
Parameters
43+
----------
44+
zipname : str
45+
Name of the ZIP file to download, without the '.zip' extension.
46+
urllink : str
47+
Base URL where the ZIP file is hosted.
48+
datapath : str, optional
49+
Path to the directory where the ZIP will be downloaded and extracted.
50+
Defaults to 'data/'.
51+
52+
Returns
53+
-------
54+
List[pd.DataFrame]
55+
A list of DataFrames loaded from the CSV files within the extracted directory.
56+
"""
3957
path_zip = os.path.join(datapath, zipname)
4058
path_zip_ext = path_zip + ".zip"
4159
url = os.path.join(urllink, zipname) + ".zip"
@@ -50,6 +68,23 @@ def download_data_from_zip(
5068

5169

5270
def get_dataframes_in_folder(path: str, extension: str) -> List[pd.DataFrame]:
71+
"""
72+
Loads all dataframes from files with a specified extension within a directory, including
73+
subdirectories. Special handling for '.tsf' files which are converted and immediately returned.
74+
75+
Parameters
76+
----------
77+
path : str
78+
Path to the directory to search for files.
79+
extension : str
80+
File extension to filter files by, e.g., '.csv'.
81+
82+
Returns
83+
-------
84+
List[pd.DataFrame]
85+
A list of pandas DataFrames loaded from the files matching the extension.
86+
If a '.tsf' file is found, its converted DataFrame is returned immediately.
87+
"""
5388
list_df = []
5489
for folder, _, files in os.walk(path):
5590
for file in files:
@@ -61,7 +96,37 @@ def get_dataframes_in_folder(path: str, extension: str) -> List[pd.DataFrame]:
6196
return list_df
6297

6398

64-
def generate_artificial_ts(n_samples, periods, amp_anomalies, ratio_anomalies, amp_noise):
99+
def generate_artificial_ts(
100+
n_samples: int,
101+
periods: List[int],
102+
amp_anomalies: float,
103+
ratio_anomalies: float,
104+
amp_noise: float,
105+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
106+
"""
107+
Generates time series data, anomalies, and noise based on given parameters.
108+
109+
Parameters
110+
----------
111+
n_samples : int
112+
Number of samples in the time series.
113+
periods : List[int]
114+
List of periods that are added to the time series.
115+
amp_anomalies : float
116+
Amplitude multiplier for anomalies.
117+
ratio_anomalies : float
118+
Ratio of total samples that will be anomalies.
119+
amp_noise : float
120+
Standard deviation of Gaussian noise.
121+
122+
Returns
123+
-------
124+
Tuple[np.ndarray, np.ndarray, np.ndarray]
125+
Time series data with sine waves (X).
126+
Anomaly data with specified amplitudes at random positions (A).
127+
Gaussian noise added to the time series (E).
128+
"""
129+
65130
mesh = np.arange(n_samples)
66131
X = np.ones(n_samples)
67132
for p in periods:
@@ -83,7 +148,8 @@ def get_data(
83148
datapath: str = "data/",
84149
n_groups_max: int = sys.maxsize,
85150
) -> pd.DataFrame:
86-
"""Download or generate data
151+
"""
152+
Download or generate data
87153
88154
Parameters
89155
----------
@@ -102,39 +168,16 @@ def get_data(
102168
if name_data == "Beijing":
103169
df = read_csv_local("beijing")
104170
df["date"] = pd.to_datetime(df["date"])
105-
106-
# df["date"] = pd.to_datetime(
107-
# {
108-
# "year": df["year"],
109-
# "month": df["month"],
110-
# "day": df["day"],
111-
# "hour": df["hour"],
112-
# }
113-
# )
114171
df = df.drop(columns=["year", "month", "day", "hour", "wd"])
115-
# df = df.set_index(["station", "date"])
116172
df = df.groupby(["station", "date"]).mean()
117173
return df
118174
elif name_data == "Superconductor":
119175
df = read_csv_local("conductors")
120176
return df
121177
elif name_data == "Titanic":
122-
# df = read_csv_local("titanic", sep=";")
123178
path = "https://gist.githubusercontent.com/fyyying/4aa5b471860321d7b47fd881898162b7/raw/"
124179
"6907bb3a38bfbb6fccf3a8b1edfb90e39714d14f/titanic_dataset.csv"
125180
df = pd.read_csv(path)
126-
# df = df.dropna(how="all")
127-
# df = df.drop(
128-
# columns=[
129-
# "pclass",
130-
# "name",
131-
# "home.dest",
132-
# "cabin",
133-
# "ticket",
134-
# "boat",
135-
# "body",
136-
# ]
137-
# )
138181
df = df[["Survived", "Sex", "Age", "SibSp", "Parch", "Fare", "Embarked"]]
139182
df["Age"] = pd.to_numeric(df["Age"], errors="coerce")
140183
df["Fare"] = pd.to_numeric(df["Fare"], errors="coerce")
@@ -276,22 +319,16 @@ def add_holes(df: pd.DataFrame, ratio_masked: float, mean_size: int) -> pd.DataF
276319
277320
ratio_masked : float
278321
Targeted global proportion of nans added in the returned dataset
279-
280-
groups: list of strings
281-
List of the column names used as groups
282-
283322
Returns
284323
-------
285324
pd.DataFrame
286325
dataframe with missing values
287326
"""
288-
try:
289-
groups = df.index.names.difference(["datetime", "date", "index"])
327+
groups = df.index.names.difference(["datetime", "date", "index"])
328+
if groups != []:
290329
generator = missing_patterns.GeometricHoleGenerator(
291330
1, ratio_masked=ratio_masked, subset=df.columns, groups=groups
292331
)
293-
except ValueError:
294-
print("No group")
295332
else:
296333
generator = missing_patterns.GeometricHoleGenerator(
297334
1, ratio_masked=ratio_masked, subset=df.columns
@@ -392,42 +429,27 @@ def convert_tsf_to_dataframe(
392429
col_types = []
393430
all_data = {}
394431
line_count = 0
395-
# frequency = None
396-
# forecast_horizon = None
397-
# contain_missing_values = None
398-
# contain_equal_length = None
399432
found_data_tag = False
400433
found_data_section = False
401434
started_reading_data_section = False
402435

403436
with open(full_file_path_and_name, "r", encoding="cp1252") as file:
404437
for line in file:
405-
# Strip white space from start/end of line
406438
line = line.strip()
407439

408440
if line:
409-
if line.startswith("@"): # Read meta-data
441+
if line.startswith("@"):
410442
if not line.startswith("@data"):
411443
line_content = line.split(" ")
412444
if line.startswith("@attribute"):
413-
if len(line_content) != 3: # Attributes have both name and type
445+
if len(line_content) != 3:
414446
raise Exception("Invalid meta-data specification.")
415447

416448
col_names.append(line_content[1])
417449
col_types.append(line_content[2])
418450
else:
419-
if len(line_content) != 2: # Other meta-data have only values
451+
if len(line_content) != 2:
420452
raise Exception("Invalid meta-data specification.")
421-
422-
# if line.startswith("@frequency"):
423-
# frequency = line_content[1]
424-
# elif line.startswith("@horizon"):
425-
# forecast_horizon = int(line_content[1])
426-
# elif line.startswith("@missing"):
427-
# contain_missing_values = bool(strtobool(line_content[1]))
428-
# elif line.startswith("@equallength"):
429-
# contain_equal_length = bool(strtobool(line_content[1]))
430-
431453
else:
432454
if len(col_names) == 0:
433455
raise Exception("Attribute section must come before data.")

tests/benchmark/test_comparator.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import pytest
2+
import numpy as np
3+
import pandas as pd
4+
5+
from unittest.mock import patch, MagicMock
6+
from qolmat.benchmark.comparator import Comparator
7+
8+
generator_holes_mock = MagicMock()
9+
generator_holes_mock.split.return_value = [
10+
pd.DataFrame({"A": [False, False, True], "B": [True, False, False]})
11+
]
12+
13+
comparator = Comparator(
14+
dict_models={},
15+
selected_columns=["A", "B"],
16+
generator_holes=generator_holes_mock,
17+
metrics=["mae", "mse"],
18+
)
19+
20+
imputer_mock = MagicMock()
21+
expected_get_errors = pd.Series(
22+
[1.0, 1.0, 1.0, 1.0],
23+
index=pd.MultiIndex.from_tuples([("mae", "A"), ("mae", "B"), ("mse", "A"), ("mse", "B")]),
24+
)
25+
26+
27+
@patch("qolmat.benchmark.metrics.get_metric")
28+
def test_get_errors(mock_get_metric):
29+
df_origin = pd.DataFrame({"A": [1, np.nan, 3], "B": [np.nan, 5, 6]})
30+
df_imputed = pd.DataFrame({"A": [1, 2, 4], "B": [4, 5, 7]})
31+
df_mask = pd.DataFrame({"A": [False, False, True], "B": [False, False, True]})
32+
33+
mock_get_metric.return_value = lambda df_origin, df_imputed, df_mask: pd.Series(
34+
[1.0, 1.0], index=["A", "B"]
35+
)
36+
errors = comparator.get_errors(df_origin, df_imputed, df_mask)
37+
pd.testing.assert_series_equal(errors, expected_get_errors)
38+
39+
40+
@patch("qolmat.benchmark.hyperparameters.optimize", return_value=imputer_mock)
41+
@patch(
42+
"qolmat.benchmark.comparator.Comparator.get_errors",
43+
return_value=expected_get_errors,
44+
)
45+
def test_evaluate_errors_sample(mock_get_errors, mock_optimize):
46+
errors_mean = comparator.evaluate_errors_sample(
47+
imputer_mock, pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, np.nan]})
48+
)
49+
expected_errors_mean = expected_get_errors
50+
pd.testing.assert_series_equal(errors_mean, expected_errors_mean)
51+
mock_optimize.assert_called_once()
52+
mock_get_errors.assert_called()
53+
54+
55+
@patch(
56+
"qolmat.benchmark.comparator.Comparator.evaluate_errors_sample",
57+
return_value=expected_get_errors,
58+
)
59+
def test_compare(mock_evaluate_errors_sample):
60+
df_test = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
61+
62+
imputer1 = MagicMock(name="Imputer1")
63+
imputer2 = MagicMock(name="Imputer2")
64+
comparator.dict_imputers = {"imputer1": imputer1, "imputer2": imputer2}
65+
66+
errors_imputer1 = pd.Series([0.1, 0.2], index=["mae", "mse"])
67+
errors_imputer2 = pd.Series([0.3, 0.4], index=["mae", "mse"])
68+
mock_evaluate_errors_sample.side_effect = [errors_imputer1, errors_imputer2]
69+
70+
df_errors = comparator.compare(df_test)
71+
assert mock_evaluate_errors_sample.call_count == 2
72+
73+
mock_evaluate_errors_sample.assert_any_call(imputer1, df_test, {}, "mse")
74+
mock_evaluate_errors_sample.assert_any_call(imputer2, df_test, {}, "mse")
75+
expected_df_errors = pd.DataFrame(
76+
{"imputer1": [0.1, 0.2], "imputer2": [0.3, 0.4]}, index=["mae", "mse"]
77+
)
78+
pd.testing.assert_frame_equal(df_errors, expected_df_errors)

0 commit comments

Comments (0)