
Commit ed6a3ca

Add support for instance IDs (#180)

* fixes
* update pre-commit
* changelog
1 parent 41539cb commit ed6a3ca

7 files changed (+44 −19 lines)

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,5 +1,5 @@
 exclude: ^LICENSE/|\.(html|csv|svg|md)$
-default_stages: [commit]
+default_stages: [pre-commit, commit, pre-push]
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.5.0
```

docs/changelog.md

Lines changed: 3 additions & 1 deletion

```diff
@@ -8,10 +8,12 @@ v0.8.0 (December X, X)
     * Update LLM helper to support updated GPT-4 models [#174][#174]
     * Update ruff to latest and remove black as a development dependency [#174][#174]
     * Add Python 3.11 markers and CI testing [#174][#174]
+    * Add support for instance IDs when generating target values [#180][#180]
 * Fixes
-    *
+    * Fix verbose print out during target value generation [#180][#180]

 [#174]: <https://github.com/trane-dev/Trane/pull/174>
+[#180]: <https://github.com/trane-dev/Trane/pull/180>

 v0.7.0 (October 21, 2023)
 =========================
```

tests/test_problem_generator.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -36,15 +36,15 @@ def test_problem_generator_single_table():
     # 3. Generate target values for each problem
     for p in problems:
         if p.has_parameters_set() is True:
-            labels = p.create_target_values(dataframe)
+            labels = p.create_target_values(dataframe, verbose=False)
             if labels.empty:
                 raise ValueError("labels should not be empty")
             check_problem_type(labels, p.get_problem_type())
         else:
             thresholds = p.get_recommended_thresholds(dataframe)
             for threshold in thresholds:
                 p.set_parameters(threshold)
-                labels = p.create_target_values(dataframe)
+                labels = p.create_target_values(dataframe, verbose=False)
                 check_problem_type(labels, p.get_problem_type())


@@ -84,13 +84,13 @@ def test_problem_generator_multi(tables, target_table):
     string_repr = p.__repr__()
     assert "2 days" in string_repr
     if p.has_parameters_set() is True:
-        labels = p.create_target_values(dataframes)
+        labels = p.create_target_values(dataframes, verbose=False)
         check_problem_type(labels, p.get_problem_type())
     else:
         thresholds = p.get_recommended_thresholds(dataframes)
         for threshold in thresholds:
             p.set_parameters(threshold)
-            labels = p.create_target_values(dataframes)
+            labels = p.create_target_values(dataframes, verbose=False)
             check_problem_type(labels, p.get_problem_type())
```
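The test flow above (label directly when parameters are already set, otherwise iterate over recommended thresholds) can be factored into a helper. A minimal sketch assuming the same problem interface used in these tests; `label_problem` is illustrative and not part of the commit:

```python
def label_problem(p, data, verbose=False):
    # Mirror the test flow: label directly if the problem's parameters are
    # already set, otherwise try each recommended threshold before labeling.
    if p.has_parameters_set():
        return [p.create_target_values(data, verbose=verbose)]
    labels = []
    for threshold in p.get_recommended_thresholds(data):
        p.set_parameters(threshold)
        labels.append(p.create_target_values(data, verbose=verbose))
    return labels
```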

trane/core/problem.py

Lines changed: 14 additions & 1 deletion

```diff
@@ -127,7 +127,13 @@ def get_recommended_thresholds(self, dataframes, n_quantiles=10):
         )
         return thresholds

-    def create_target_values(self, dataframes, verbose=False):
+    def create_target_values(
+        self,
+        dataframes,
+        verbose=False,
+        nrows=None,
+        instance_ids=None,
+    ):
         # Won't this always be normalized?
         normalized_dataframe = self.get_normalized_dataframe(dataframes)
         if self.has_parameters_set() is False:
@@ -141,6 +147,12 @@ def create_target_values(self, dataframes, verbose=False):
         # create a fake index with all rows to generate predictions problems "Predict X"
         normalized_dataframe["__identity__"] = 0
         target_dataframe_index = "__identity__"
+        if instance_ids and len(instance_ids) > 0:
+            if verbose:
+                print("Only selecting given instance IDs")
+            normalized_dataframe = normalized_dataframe[
+                normalized_dataframe[self.entity_column].isin(instance_ids)
+            ]

         lt = calculate_target_values(
             df=normalized_dataframe,
@@ -149,6 +161,7 @@ def create_target_values(self, dataframes, verbose=False):
             time_index=self.metadata.time_index,
             window_size=self.window_size,
             verbose=verbose,
+            nrows=nrows,
         )
         if "__identity__" in normalized_dataframe.columns:
             normalized_dataframe.drop(columns=["__identity__"], inplace=True)
```
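With the expanded signature, callers can restrict labeling to particular entities and cap the number of rows considered. A minimal usage sketch, assuming `problem` is a Trane problem object with parameters set and `dataframe` is the entity/time-indexed data used elsewhere in the tests; the helper name and ID values are illustrative:

```python
def label_selected_instances(problem, dataframe):
    # New keyword arguments introduced in this commit:
    #   nrows        -- sample at most this many rows before labeling
    #   instance_ids -- only generate target values for these entity IDs
    return problem.create_target_values(
        dataframe,
        verbose=True,           # prints "Only selecting given instance IDs"
        nrows=1_000,
        instance_ids=[1, 2, 3],
    )
```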

trane/core/utils.py

Lines changed: 13 additions & 4 deletions

```diff
@@ -1,13 +1,15 @@
 import pandas as pd


-def set_dataframe_index(df, index):
+def set_dataframe_index(df, index, verbose=False):
     if df.index.name != index:
+        if verbose:
+            print(f"setting dataframe to index: {index}")
         df = df.set_index(index, inplace=False)
     return df


-def generate_data_slices(df, window_size, gap, drop_empty=True):
+def generate_data_slices(df, window_size, gap, drop_empty=True, verbose=False):
     # valid for a specify group of id
     # so we need to groupby id (before this function)
     window_size = pd.to_timedelta(window_size)
@@ -36,18 +38,25 @@ def calculate_target_values(
     window_size,
     drop_empty=True,
     verbose=False,
+    nrows=None,
 ):
-    df = set_dataframe_index(df, time_index)
+    df = set_dataframe_index(df, time_index, verbose=verbose)
+    if str(df.index.dtype) == "timestamp[ns][pyarrow]":
+        df.index = df.index.astype("datetime64[ns]")
+    if nrows and nrows > 0 and nrows < len(df):
+        if verbose:
+            print("sampling {nrows} rows")
+        df = df.sample(n=nrows)
     records = []
     label_name = labeling_function.__name__
-
     for group_key, df_by_index in df.groupby(target_dataframe_index, observed=True):
         # TODO: support gap
         for dataslice, _ in generate_data_slices(
             df=df_by_index,
             window_size=window_size,
             gap=window_size,
             drop_empty=drop_empty,
+            verbose=verbose,
         ):
             record = labeling_function(dataslice)
             records.append(
```
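The sampling guard added to `calculate_target_values` only fires when a positive `nrows` smaller than the frame is passed (note the diff's `print("sampling {nrows} rows")` has no f-string prefix, so the placeholder is printed literally). A standalone sketch of the same guard, with an illustrative dataframe:

```python
import pandas as pd

def maybe_sample(df, nrows=None, verbose=False):
    # Mirrors the guard in calculate_target_values: sample only when nrows is
    # a positive number smaller than the dataframe's length.
    if nrows and nrows > 0 and nrows < len(df):
        if verbose:
            print(f"sampling {nrows} rows")
        df = df.sample(n=nrows)
    return df

df = pd.DataFrame({"amount": range(10)})
print(len(maybe_sample(df, nrows=3)))     # 3
print(len(maybe_sample(df, nrows=None)))  # 10 -- no sampling
```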

trane/metadata/metadata.py

Lines changed: 0 additions & 7 deletions

```diff
@@ -164,13 +164,6 @@ def reset_primary_key(self, table):
         self.check_if_table_exists(table)
         self.primary_keys.pop(table, None)

-    def obi(self, table):
-        self.check_if_table_exists(table)
-        if self.primary_keys:
-            primary_key = self.primary_keys[table]
-            self.ml_types[table][primary_key].remove_tag("primary_key")
-            self.primary_keys.pop(table)
-
     def add_table(self, table, ml_types):
         if table in self.ml_types:
             raise ValueError("Table already exists")
```

trane/parsing/denormalize.py

Lines changed: 9 additions & 1 deletion

```diff
@@ -47,6 +47,7 @@ def denormalize_dataframes(
     ml_types: Dict[str, Dict[str, str]],
     target_table: str,
 ) -> pd.DataFrame:
+    keys_to_ml_type = {}
     merged_dataframes = {}
     for relationship in relationships:
         parent_table_name, parent_key, child_table_name, child_key = relationship
@@ -58,6 +59,8 @@ def denormalize_dataframes(
             raise ValueError(
                 f"{child_key} not in table: {child_table_name}",
             )
+        keys_to_ml_type[parent_key] = ml_types.get(parent_table_name).get(parent_key)
+        keys_to_ml_type[child_key] = ml_types.get(child_table_name).get(child_key)
     check_target_table(target_table, relationships, list(dataframes.keys()))
     relationship_order = child_relationships(target_table, relationships)
     if len(relationship_order) == 0:
@@ -110,7 +113,12 @@ def denormalize_dataframes(
     # TODO: set primary key to be the index
     # TODO: pass information to table meta (primary key, foreign keys)? maybe? technically relationships has this info
     valid_columns = list(merged_dataframes[target_table].columns)
-    col_to_ml_type = {col: column_to_ml_type[col] for col in valid_columns}
+    col_to_ml_type = {}
+    for col in valid_columns:
+        if col in column_to_ml_type:
+            col_to_ml_type[col] = column_to_ml_type[col]
+        else:
+            col_to_ml_type[col] = keys_to_ml_type[col]
     return merged_dataframes, col_to_ml_type
```
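The new `keys_to_ml_type` mapping lets the final lookup fall back to the ML type recorded for a join key when a merged column is missing from `column_to_ml_type`. A self-contained sketch of that fallback, with illustrative dictionaries and column names:

```python
# Illustrative inputs; in denormalize_dataframes these are built from the
# merged target table and from the relationships walked earlier.
column_to_ml_type = {"amount": "Double", "transaction_time": "Datetime"}
keys_to_ml_type = {"customer_id": "Integer"}
valid_columns = ["amount", "transaction_time", "customer_id"]

col_to_ml_type = {}
for col in valid_columns:
    if col in column_to_ml_type:
        # Column carries its own ML type.
        col_to_ml_type[col] = column_to_ml_type[col]
    else:
        # Fall back to the ML type recorded for the join key of the same name.
        col_to_ml_type[col] = keys_to_ml_type[col]

print(col_to_ml_type)
# {'amount': 'Double', 'transaction_time': 'Datetime', 'customer_id': 'Integer'}
```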

0 commit comments