-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmutable_dataframe.py
More file actions
236 lines (200 loc) · 8.55 KB
/
mutable_dataframe.py
File metadata and controls
236 lines (200 loc) · 8.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import re
from collections.abc import Generator
from typing import Any
import polars as pl
from dapla_pseudo.v1.models.core import PseudoFunction
from dapla_pseudo.v1.models.core import PseudoRule
ARRAY_INDEX_MATCHER = re.compile(r"\[\d*]")
def _ensure_normalized(pattern: str) -> str:
"""Normalize the pattern.
Ensure that the pattern always starts with a '/' or '*' to be compatible with the
pattern matching that is used in pseudo-service.
"""
return (
pattern
if (pattern.startswith("/") or pattern.startswith("*"))
else "/" + pattern
)
class FieldMatch:
    """Represents a reference to a matching column in the dataframe.

    Holds everything needed to read the matched values out and to write
    pseudonymized values back in later: the path, the normalized match
    pattern, an indexer describing how to reach the leaf in a nested
    structure, and the pseudo function(s) attached to the matching rule.
    """

    def __init__(
        self,
        path: str,
        pattern: str,
        indexer: list[str | int],
        col: list[Any],
        wrapped_list: bool,
        func: PseudoFunction,  # "source func" if repseudo
        target_func: PseudoFunction | None,  # "target_func" if repseudo, else None
    ) -> None:
        """Initialize the class."""
        self.path = path
        # Normalize once at construction so downstream comparisons against
        # pseudo-service patterns are consistent.
        self.pattern = _ensure_normalized(pattern)
        self.func = func
        self.target_func = target_func
        self.indexer = indexer
        self.col = col
        self.wrapped_list = wrapped_list

    def get_value(self) -> list[str | int | None]:
        """Get the inner value.

        If hierarchical, get the values of the matched column.
        Otherwise, just return the data of the Polars DataFrame.
        """
        return self.col
class MutableDataFrame:
    """A DataFrame that can change values in-place.

    Wraps either a flat Polars frame (``hierarchical=False``) or a nested
    dict representation of one (``hierarchical=True``) so that columns
    matched by pseudo rules can be read out and later written back.
    """

    def __init__(
        self, dataframe: pl.DataFrame | pl.LazyFrame, hierarchical: bool
    ) -> None:
        """Initialize the class.

        Args:
            dataframe: The Polars frame to wrap; may be lazy.
            hierarchical: True when the data contains nested values that must
                be traversed path-by-path instead of as flat columns.
        """
        self.dataset: pl.DataFrame | dict[str, Any] | pl.LazyFrame = dataframe
        # Maps a key (the column path for hierarchical data, a running index
        # otherwise) to the FieldMatch describing the matched column.
        self.matched_fields: dict[str, FieldMatch] = {}
        self.matched_fields_metrics: dict[str, int] | None = None
        self.hierarchical: bool = hierarchical
        # Snapshot the schema now: it is re-applied in to_polars() for
        # hierarchical data, and a LazyFrame requires collect_schema().
        self.schema = (
            dataframe.schema
            if isinstance(dataframe, pl.DataFrame)
            else dataframe.collect_schema()
        )

    def match_rules(
        self, rules: list[PseudoRule], target_rules: list[PseudoRule] | None
    ) -> None:
        """Create references to all the columns that matches the given pseudo rules.

        Args:
            rules: The pseudo rules to match against the data.
            target_rules: Optional target rules (repseudo), paired 1:1 with rules.

        Raises:
            ValueError: If a hierarchical rule has no concrete path.
        """
        if self.hierarchical is False:
            assert isinstance(self.dataset, (pl.DataFrame, pl.LazyFrame))

            def extract_column_data(
                pattern: str, dataset: pl.DataFrame | pl.LazyFrame
            ) -> list[Any]:
                # For flat data the rule pattern is a plain column name.
                if isinstance(dataset, pl.DataFrame):
                    return list(dataset.get_column(pattern))
                return list(dataset.select(pattern).collect().to_series())

            self.matched_fields = {
                str(i): FieldMatch(
                    path=rule.pattern,
                    pattern=rule.pattern,
                    indexer=[],
                    col=extract_column_data(rule.pattern, self.dataset),
                    wrapped_list=False,
                    func=rule.func,
                    target_func=target_rule.func if target_rule else None,
                )
                for (i, (rule, target_rule)) in enumerate(
                    _combine_rules(rules, target_rules)
                )
            }
        else:
            # BUGFIX: a LazyFrame must be materialized before it can be
            # converted to a nested dict; previously a LazyFrame combined
            # with hierarchical=True tripped the assert below.
            if isinstance(self.dataset, pl.LazyFrame):
                self.dataset = self.dataset.collect()
            assert isinstance(self.dataset, pl.DataFrame)
            # Work on a plain nested dict so leaves can be mutated in place.
            self.dataset = self.dataset.to_dict(as_series=False)
            assert isinstance(self.dataset, dict)

            for source_rule, target_rule in _combine_rules(rules, target_rules):
                if source_rule.path is None:
                    raise ValueError(
                        f"Rule: {source_rule}\n does not have a concrete path, and cannot be used."
                    )
                matches = _search_nested_path(
                    self.dataset,
                    source_rule.path,
                    (source_rule, target_rule),
                )
                for match in matches:
                    self.matched_fields[match.path] = match

    def get_matched_fields(self) -> dict[str, FieldMatch]:
        """Get a reference to all the columns that matched pseudo rules."""
        return self.matched_fields

    def update(self, path: str, data: list[str | None]) -> None:
        """Update a column with the given data.

        For flat data the column named ``path`` is replaced wholesale.
        For hierarchical data the previously matched leaf is written back
        through the recorded indexer; an unmatched path is silently ignored.
        """
        if self.hierarchical is False:
            assert isinstance(self.dataset, (pl.DataFrame, pl.LazyFrame))
            self.dataset = self.dataset.with_columns(pl.Series(data).alias(path))
        elif (field_match := self.matched_fields.get(path)) is not None:
            assert isinstance(self.dataset, dict)
            tree = self.dataset
            leaf_key = field_match.indexer[-1]  # Either a dict key or a list index
            # Walk down to the direct parent container of the leaf.
            for idx in field_match.indexer[:-1]:
                tree = tree[idx]  # type: ignore[index]
            # Single values were wrapped in a list for transport; unwrap them.
            tree[leaf_key] = (  # type: ignore[index]
                data if field_match.wrapped_list is False else data[0]
            )

    def to_polars(self) -> pl.DataFrame | pl.LazyFrame:
        """Convert to Polars DataFrame."""
        if self.hierarchical is False:
            assert isinstance(self.dataset, (pl.DataFrame, pl.LazyFrame))
            return self.dataset
        else:
            assert isinstance(self.dataset, dict)
            # Re-apply the captured schema: round-tripping through plain
            # dicts loses dtype information.
            return pl.from_dict(self.dataset, schema_overrides=self.schema)
def _combine_rules(
    rules: list[PseudoRule], target_rules: list[PseudoRule] | None
) -> list[tuple[PseudoRule, PseudoRule | None]]:
    """Pair each rule with its positional target rule (None if no targets)."""
    # Index into target_rules instead of zipping so that a too-short target
    # list still raises IndexError, exactly like the original loop.
    return [
        (rule, target_rules[index] if target_rules else None)
        for index, rule in enumerate(rules)
    ]
def _search_nested_path(
    data: dict[str, Any] | list[Any],
    path: str,
    rules: tuple[PseudoRule, PseudoRule | None],
) -> Generator[FieldMatch, None, None]:
    """Search in the hierarchical data structure for the data at a given path.

    Args:
        data: The hierarchical data structure to search.
        path: The path to search for in the data structure.
        rules: The pseudo rules for the path.

    Yields:
        Generator[FieldMatch, None, None]: A generator yielding FieldMatch objects.
    """
    keys = path.strip("/").split("/")

    def _search(
        current_tree: dict[str, Any] | list[Any] | str | None,
        remaining_keys: list[str],
        rules: tuple[PseudoRule, PseudoRule | None],
        indexer: list[str | int],
        curr_path: list[str],
    ) -> Generator[FieldMatch, None, None]:
        # Missing/null values cannot match anything.
        if current_tree is None:
            return

        if not remaining_keys:  # Base case: No more keys to process, reached leaf node
            source_rule, target_rule = rules
            # If the current value is not a list, we need to wrap it in a list
            # in order to send it to pseudo-service.
            # We record whether the value was a wrapped list primitive or an
            # actual list value so we can unwrap it later when updating the data.
            wrap_in_list = isinstance(current_tree, list) is False
            yield FieldMatch(
                path="/".join(curr_path),
                pattern=source_rule.pattern,
                indexer=indexer,
                col=([current_tree] if wrap_in_list else current_tree),  # type: ignore[arg-type]
                wrapped_list=wrap_in_list,
                func=source_rule.func,
                target_func=target_rule.func if target_rule else None,
            )
            return

        key = remaining_keys[0]
        if isinstance(current_tree, dict):  # Recursive case: Traverse dictionary
            if key in current_tree:
                yield from _search(
                    current_tree[key],
                    remaining_keys[1:],
                    rules,
                    [*indexer, key],
                    [*curr_path, key],
                )
        elif isinstance(current_tree, list):  # Recursive case: Traverse list
            for idx, item in enumerate(current_tree):
                # Record the list position in the path, e.g. "parent[0]".
                # BUGFIX: guard against an empty curr_path (root-level list
                # input), which previously raised IndexError on curr_path[-1].
                if curr_path:
                    branch_path = [*curr_path[:-1], f"{curr_path[-1]}[{idx}]"]
                else:
                    branch_path = [f"[{idx}]"]
                yield from _search(
                    item,
                    remaining_keys,
                    rules,
                    [*indexer, idx],
                    branch_path,
                )

    yield from _search(data, keys, rules, [], [])