-
Notifications
You must be signed in to change notification settings - Fork 69
Expand file tree
/
Copy pathdataset_explorer.py
More file actions
308 lines (259 loc) · 11.2 KB
/
dataset_explorer.py
File metadata and controls
308 lines (259 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
"""Dataset Explorer to summarize datasets stored in S3 buckets."""
import warnings
from collections import defaultdict
from pathlib import Path
from urllib.parse import urlparse
import pandas as pd
from sdv.metadata import Metadata
from sdgym.datasets import BUCKET, _get_available_datasets, _validate_modality, load_dataset
# Ordered column names of the dataset summary table. ``DatasetExplorer`` also uses
# this list to build an empty summary when no datasets of a modality are found.
SUMMARY_OUTPUT_COLUMNS = [
    # Dataset identity and size.
    'Dataset', 'Datasize_Size_MB', 'Num_Tables',
    # Column counts aggregated from the metadata.
    'Total_Num_Columns', 'Total_Num_Columns_Categorical', 'Total_Num_Columns_Numerical',
    'Total_Num_Columns_Datetime', 'Total_Num_Columns_PII', 'Total_Num_Columns_ID_NonKey',
    'Max_Num_Columns_Per_Table',
    # Row counts aggregated from the data.
    'Total_Num_Rows', 'Max_Num_Rows_Per_Table',
    # Relationship / schema-shape statistics.
    'Num_Relationships', 'Max_Schema_Depth', 'Max_Schema_Branch',
]
class DatasetExplorer:
    """``DatasetExplorer`` class.

    This class provides utilities to analyze datasets hosted on S3 by loading
    their metadata and data, computing schema and data summaries, and optionally
    saving the results as a CSV file.

    Args:
        s3_url (str, optional):
            The base S3 bucket URL containing the datasets (e.g. ``s3://my-bucket``).
            Only the bucket part of the URL (its network location) is used to load
            datasets. Defaults to ``sdgym.datasets.BUCKET`` (``s3://sdv-datasets-public``).
        aws_access_key_id (str, optional):
            AWS access key ID for authentication. Defaults to ``None``.
        aws_secret_access_key (str, optional):
            AWS secret access key for authentication. Defaults to ``None``.
    """

    def __init__(self, s3_url=BUCKET, aws_access_key_id=None, aws_secret_access_key=None):
        self.s3_url = s3_url
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        # The dataset loaders receive only the bucket name; any key prefix in
        # ``s3_url`` (``s3://bucket/prefix``) is discarded here.
        # NOTE(review): a bare bucket name without the ``s3://`` scheme parses to an
        # empty netloc.
        self._bucket_name = urlparse(self.s3_url).netloc

    @staticmethod
    def _get_max_schema_branch_factor(relationships):
        """Compute the maximum number of relationships that start at a single parent table.

        Args:
            relationships (list[dict]):
                A list of relationship dictionaries describing parent-child table
                relationships. Each one must contain a ``'parent_table_name'`` key.

        Returns:
            int:
                The largest number of relationships sharing the same parent table, or
                ``0`` when there are no relationships. Every relationship is counted, so
                a parent linked to the same child through two relationships counts twice.
        """
        relationships_per_parent = defaultdict(int)
        for relationship in relationships:
            relationships_per_parent[relationship['parent_table_name']] += 1

        return max(relationships_per_parent.values(), default=0)

    @staticmethod
    def _get_max_depth(metadata):
        """Calculate the maximum depth of a metadata schema.

        The depth is the number of tables in the longest parent -> child chain, so a
        metadata without any relationships has a depth of ``1``.

        Args:
            metadata (sdv.metadata.Metadata):
                The SDV Metadata object representing the dataset.

        Returns:
            int:
                The maximum schema depth (i.e., the longest parent-child relationship chain).
        """
        child_map = metadata._get_child_map()
        if not any(child_map.values()):
            return 1

        # Only needed (and fetched once) when there are relationships to walk.
        parent_map = metadata._get_parent_map()

        def _longest_chain_from(table):
            # Number of tables on the longest downward path starting at ``table``.
            children = child_map[table] if table in child_map else ()
            if not children:
                return 1

            return 1 + max(_longest_chain_from(child) for child in children)

        # Roots are parent tables that are not themselves children of another table.
        # NOTE(review): assumes the relationship graph is acyclic; a cycle would leave no
        # root (``max`` of an empty sequence) or make the recursion unbounded.
        root_tables = [table for table in child_map if table not in parent_map]
        return max(_longest_chain_from(root) for root in root_tables)

    @staticmethod
    def _summarize_metadata_columns(metadata):
        """Summarize column-level details from a dataset's metadata.

        Every column counts towards ``Total_Num_Columns``. For the per-type counters the
        checks are applied in order: ``'categorical'``/``'boolean'``, ``'numerical'`` and
        ``'datetime'`` sdtypes are counted by type; remaining ``'id'`` columns that are not
        the table's primary key count as ``Total_Num_Columns_ID_NonKey``; any other column
        that is not the primary key counts as ``Total_Num_Columns_PII``.

        Args:
            metadata (sdv.metadata.Metadata):
                The SDV Metadata object containing table and column information.

        Returns:
            dict:
                A dictionary summarizing total and per-type column counts across all
                tables, plus the widest table's column count.
        """
        results = {
            'Total_Num_Columns': 0,
            'Total_Num_Columns_Categorical': 0,
            'Total_Num_Columns_Numerical': 0,
            'Total_Num_Columns_Datetime': 0,
            'Total_Num_Columns_PII': 0,
            'Total_Num_Columns_ID_NonKey': 0,
            'Max_Num_Columns_Per_Table': 0,
        }
        for table in metadata.tables.values():
            num_cols = len(table.columns)
            results['Total_Num_Columns'] += num_cols
            results['Max_Num_Columns_Per_Table'] = max(
                results['Max_Num_Columns_Per_Table'], num_cols
            )
            for column_name, column in table.columns.items():
                sdtype = column['sdtype']
                is_primary_key = column_name == table.primary_key
                if sdtype in ('categorical', 'boolean'):
                    results['Total_Num_Columns_Categorical'] += 1
                elif sdtype == 'numerical':
                    results['Total_Num_Columns_Numerical'] += 1
                elif sdtype == 'datetime':
                    results['Total_Num_Columns_Datetime'] += 1
                elif sdtype == 'id' and not is_primary_key:
                    # NOTE(review): foreign, sequence and alternate keys with an ``'id'``
                    # sdtype are counted here too; confirm that is what "NonKey" means.
                    results['Total_Num_Columns_ID_NonKey'] += 1
                elif not is_primary_key:
                    results['Total_Num_Columns_PII'] += 1

        return results

    @staticmethod
    def get_metadata_summary(metadata):
        """Summarize schema-level information from dataset metadata.

        Args:
            metadata (dict or Metadata):
                The dataset metadata as a dictionary or SDV Metadata object. A dictionary
                is converted with ``Metadata.load_from_dict`` first.

        Returns:
            dict:
                The column counts from ``_summarize_metadata_columns`` plus
                ``Num_Relationships``, ``Max_Schema_Depth`` and ``Max_Schema_Branch``.
        """
        if isinstance(metadata, dict):
            metadata = Metadata.load_from_dict(metadata)

        metadata_summary = DatasetExplorer._summarize_metadata_columns(metadata)
        metadata_summary.update({
            'Num_Relationships': len(metadata.relationships),
            'Max_Schema_Depth': DatasetExplorer._get_max_depth(metadata),
            'Max_Schema_Branch': DatasetExplorer._get_max_schema_branch_factor(
                metadata.relationships
            ),
        })
        return metadata_summary

    @staticmethod
    def get_data_summary(data):
        """Summarize record-level information from dataset tables.

        Args:
            data (dict[str, pd.DataFrame] or pd.DataFrame):
                The dataset data, either as a dictionary of table DataFrames or a single
                DataFrame (treated as one table).

        Returns:
            dict:
                A dictionary with the total number of rows across tables and the row
                count of the largest table.
        """
        data_dict = data if isinstance(data, dict) else {'dataset': data}
        data_summary = {
            'Total_Num_Rows': 0,
            'Max_Num_Rows_Per_Table': 0,
        }
        for table in data_dict.values():
            table_num_rows = len(table)
            data_summary['Total_Num_Rows'] += table_num_rows
            data_summary['Max_Num_Rows_Per_Table'] = max(
                data_summary['Max_Num_Rows_Per_Table'], table_num_rows
            )

        return data_summary

    def _load_and_summarize_datasets(self, modality):
        """Load all datasets for the given modality and compute summary statistics.

        Args:
            modality (str):
                The dataset modality to load (e.g. ``'single_table'`` or ``'multi_table'``).

        Returns:
            list[dict]:
                One dictionary per dataset combining its metadata and data summaries. The
                keys follow the order of ``SUMMARY_OUTPUT_COLUMNS``.
        """
        datasets = _get_available_datasets(
            modality=modality,
            bucket=self._bucket_name,
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
        )
        results = []
        for _, dataset_row in datasets.iterrows():
            dataset_name = dataset_row['dataset_name']
            data, metadata_dict = load_dataset(
                modality,
                dataset=dataset_name,
                bucket=self._bucket_name,
                aws_access_key_id=self.aws_access_key_id,
                aws_secret_access_key=self.aws_secret_access_key,
            )
            metadata_stats = DatasetExplorer.get_metadata_summary(metadata_dict)
            data_stats = DatasetExplorer.get_data_summary(data)
            # Move the schema-shape statistics behind the row counts so the record's
            # key order (and therefore the DataFrame's column order) matches
            # SUMMARY_OUTPUT_COLUMNS.
            schema_stats = {
                key: metadata_stats.pop(key)
                for key in ('Num_Relationships', 'Max_Schema_Depth', 'Max_Schema_Branch')
            }
            results.append({
                'Dataset': dataset_name,
                'Datasize_Size_MB': dataset_row['size_MB'],
                'Num_Tables': dataset_row['num_tables'],
                **metadata_stats,
                **data_stats,
                **schema_stats,
            })

        return results

    def _validate_output_filepath(self, output_filepath):
        """Validate that the provided output path has a .csv file extension.

        Args:
            output_filepath (str or None):
                The file path to validate. ``None`` or an empty string skips the check.

        Raises:
            ValueError:
                If the provided path is non-empty and its suffix is not exactly ``'.csv'``
                (the comparison is case-sensitive).
        """
        if output_filepath and Path(output_filepath).suffix != '.csv':
            raise ValueError(
                f"The 'output_filepath' has to be a .csv file, provided: '{output_filepath}'."
            )

    def summarize_datasets(self, modality, output_filepath=None):
        """Load, summarize, and optionally export dataset statistics for a given modality.

        Args:
            modality (str):
                It must be ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
            output_filepath (str, optional):
                The path to save the summary as a CSV file. If ``None``, results are
                returned only. When given, the summary is written even if it is empty.

        Returns:
            pd.DataFrame:
                A DataFrame containing aggregated dataset summaries including schema and
                data-level statistics. It is empty (with ``SUMMARY_OUTPUT_COLUMNS`` as
                columns) when no dataset of the modality is found.

        Raises:
            ValueError:
                If ``output_filepath`` is provided and does not have a '.csv' extension.
            ValueError:
                If the modality provided is not ``single_table``, ``multi_table`` or
                ``sequential``.

        Warns:
            UserWarning:
                If the bucket contains no datasets of the requested modality.
        """
        # Validate cheap inputs before any dataset is loaded.
        self._validate_output_filepath(output_filepath)
        _validate_modality(modality)

        results = self._load_and_summarize_datasets(modality)
        if not results:
            warnings.warn(
                (
                    f"The provided S3 URL '{self.s3_url}' does not contain any datasets "
                    f"of modality '{modality}'."
                ),
                UserWarning,
            )
            dataset_summary = pd.DataFrame(columns=SUMMARY_OUTPUT_COLUMNS)
        else:
            dataset_summary = pd.DataFrame(results)

        if output_filepath:
            dataset_summary.to_csv(output_filepath, index=False)

        return dataset_summary