 import zipfile
 from datetime import datetime
 from math import pi
-from typing import List
+from typing import List, Tuple, Union
 from urllib import request
 
 import numpy as np
@@ -36,6 +36,24 @@ def read_csv_local(data_file_name: str, **kwargs) -> pd.DataFrame:
 def download_data_from_zip(
     zipname: str, urllink: str, datapath: str = "data/"
 ) -> List[pd.DataFrame]:
+    """
+    Download a ZIP file from a URL, extract it, and load the contained CSV
+    files as DataFrames.
+
+    Parameters
+    ----------
+    zipname : str
+        Name of the ZIP file to download, without the '.zip' extension.
+    urllink : str
+        Base URL where the ZIP file is hosted.
+    datapath : str, optional
+        Directory where the ZIP file is downloaded and extracted.
+        Defaults to 'data/'.
+
+    Returns
+    -------
+    List[pd.DataFrame]
+        A list of DataFrames loaded from the CSV files in the extracted directory.
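+
+    Examples
+    --------
+    A minimal sketch; the host URL and archive name are illustrative only:
+
+    >>> dfs = download_data_from_zip(
+    ...     "my_dataset", "https://example.com/archives"
+    ... )  # doctest: +SKIP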
+    """
     path_zip = os.path.join(datapath, zipname)
     path_zip_ext = path_zip + ".zip"
     url = os.path.join(urllink, zipname) + ".zip"
@@ -50,6 +68,23 @@ def download_data_from_zip(
 
 
 def get_dataframes_in_folder(path: str, extension: str) -> List[pd.DataFrame]:
+    """
+    Load all DataFrames from files with the given extension in a directory,
+    searching subdirectories recursively. '.tsf' files receive special
+    handling: they are converted and returned immediately.
+
+    Parameters
+    ----------
+    path : str
+        Path of the directory to search for files.
+    extension : str
+        File extension to filter files by, e.g. '.csv'.
+
+    Returns
+    -------
+    List[pd.DataFrame]
+        A list of pandas DataFrames loaded from the files matching the extension.
+        If a '.tsf' file is found, its converted DataFrame is returned immediately.
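+
+    Examples
+    --------
+    A minimal sketch; the folder path is illustrative only:
+
+    >>> dfs = get_dataframes_in_folder("data/my_dataset", ".csv")  # doctest: +SKIP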
+    """
     list_df = []
     for folder, _, files in os.walk(path):
         for file in files:
@@ -61,7 +96,37 @@ def get_dataframes_in_folder(path: str, extension: str) -> List[pd.DataFrame]:
     return list_df
 
 
-def generate_artificial_ts(n_samples, periods, amp_anomalies, ratio_anomalies, amp_noise):
+def generate_artificial_ts(
+    n_samples: int,
+    periods: List[int],
+    amp_anomalies: float,
+    ratio_anomalies: float,
+    amp_noise: float,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Generate a time series together with anomalies and noise.
+
+    Parameters
+    ----------
+    n_samples : int
+        Number of samples in the time series.
+    periods : List[int]
+        Periods of the sine waves added to the time series.
+    amp_anomalies : float
+        Amplitude multiplier for the anomalies.
+    ratio_anomalies : float
+        Proportion of samples that are anomalies.
+    amp_noise : float
+        Standard deviation of the Gaussian noise.
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray, np.ndarray]
+        The time series built from sine waves (X), the anomalies with the
+        specified amplitude at random positions (A), and the Gaussian noise (E).
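+
+    Examples
+    --------
+    Illustrative parameter values:
+
+    >>> X, A, E = generate_artificial_ts(
+    ...     n_samples=1000, periods=[24, 168], amp_anomalies=5.0,
+    ...     ratio_anomalies=0.01, amp_noise=0.1,
+    ... )
+    >>> X.shape
+    (1000,)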
+    """
+
     mesh = np.arange(n_samples)
     X = np.ones(n_samples)
     for p in periods:
@@ -83,7 +148,8 @@ def get_data(
     datapath: str = "data/",
     n_groups_max: int = sys.maxsize,
 ) -> pd.DataFrame:
-    """Download or generate data
+    """
+    Download or generate data
 
     Parameters
     ----------
@@ -102,39 +168,16 @@ def get_data(
     if name_data == "Beijing":
         df = read_csv_local("beijing")
         df["date"] = pd.to_datetime(df["date"])
-
-        # df["date"] = pd.to_datetime(
-        #     {
-        #         "year": df["year"],
-        #         "month": df["month"],
-        #         "day": df["day"],
-        #         "hour": df["hour"],
-        #     }
-        # )
         df = df.drop(columns=["year", "month", "day", "hour", "wd"])
-        # df = df.set_index(["station", "date"])
         df = df.groupby(["station", "date"]).mean()
         return df
     elif name_data == "Superconductor":
         df = read_csv_local("conductors")
         return df
     elif name_data == "Titanic":
-        # df = read_csv_local("titanic", sep=";")
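+        # NOTE: the parentheses make the two adjacent string literals
+        # concatenate into a single URL.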
-        path = "https://gist.githubusercontent.com/fyyying/4aa5b471860321d7b47fd881898162b7/raw/"
-        "6907bb3a38bfbb6fccf3a8b1edfb90e39714d14f/titanic_dataset.csv"
+        path = (
+            "https://gist.githubusercontent.com/fyyying/4aa5b471860321d7b47fd881898162b7/raw/"
+            "6907bb3a38bfbb6fccf3a8b1edfb90e39714d14f/titanic_dataset.csv"
+        )
         df = pd.read_csv(path)
-        # df = df.dropna(how="all")
-        # df = df.drop(
-        #     columns=[
-        #         "pclass",
-        #         "name",
-        #         "home.dest",
-        #         "cabin",
-        #         "ticket",
-        #         "boat",
-        #         "body",
-        #     ]
-        # )
         df = df[["Survived", "Sex", "Age", "SibSp", "Parch", "Fare", "Embarked"]]
         df["Age"] = pd.to_numeric(df["Age"], errors="coerce")
         df["Fare"] = pd.to_numeric(df["Fare"], errors="coerce")
@@ -276,22 +319,16 @@ def add_holes(df: pd.DataFrame, ratio_masked: float, mean_size: int) -> pd.DataFrame:
 
     ratio_masked : float
         Targeted global proportion of NaNs added in the returned dataset
-
-    groups: list of strings
-        List of the column names used as groups
-
     Returns
     -------
     pd.DataFrame
         DataFrame with missing values
     """
-    try:
-        groups = df.index.names.difference(["datetime", "date", "index"])
+    groups = df.index.names.difference(["datetime", "date", "index"])
+    if groups != []:
         generator = missing_patterns.GeometricHoleGenerator(
             1, ratio_masked=ratio_masked, subset=df.columns, groups=groups
         )
-    except ValueError:
-        print("No group")
     else:
         generator = missing_patterns.GeometricHoleGenerator(
             1, ratio_masked=ratio_masked, subset=df.columns
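+    # Index levels other than the temporal/index ones act as grouping keys;
+    # when none remain, holes are generated without grouping.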
@@ -392,42 +429,27 @@ def convert_tsf_to_dataframe(
     col_types = []
     all_data = {}
     line_count = 0
-    # frequency = None
-    # forecast_horizon = None
-    # contain_missing_values = None
-    # contain_equal_length = None
     found_data_tag = False
     found_data_section = False
     started_reading_data_section = False
 
     with open(full_file_path_and_name, "r", encoding="cp1252") as file:
         for line in file:
-            # Strip white space from start/end of line
             line = line.strip()
 
             if line:
-                if line.startswith("@"):  # Read meta-data
+                if line.startswith("@"):
                     if not line.startswith("@data"):
                         line_content = line.split(" ")
                         if line.startswith("@attribute"):
-                            if len(line_content) != 3:  # Attributes have both name and type
+                            if len(line_content) != 3:
                                 raise Exception("Invalid meta-data specification.")
 
                             col_names.append(line_content[1])
                             col_types.append(line_content[2])
                         else:
-                            if len(line_content) != 2:  # Other meta-data have only values
+                            if len(line_content) != 2:
                                 raise Exception("Invalid meta-data specification.")
-
-                            # if line.startswith("@frequency"):
-                            #     frequency = line_content[1]
-                            # elif line.startswith("@horizon"):
-                            #     forecast_horizon = int(line_content[1])
-                            # elif line.startswith("@missing"):
-                            #     contain_missing_values = bool(strtobool(line_content[1]))
-                            # elif line.startswith("@equallength"):
-                            #     contain_equal_length = bool(strtobool(line_content[1]))
-
                     else:
                         if len(col_names) == 0:
                             raise Exception("Attribute section must come before data.")