Skip to content

Commit 7e3fb25

Browse files
committed
feat: enhance load_dataframe and add load_data_balanced for improved data handling and class balancing
1 parent 258321c commit 7e3fb25

File tree

1 file changed

+189
-35
lines changed

1 file changed

+189
-35
lines changed

extensions/data_loaders/ex_crp.py

Lines changed: 189 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
CSV_DATA_PATH = Path("/home/janezla/data/crp")
99

1010

11-
def load_dataframe() -> pd.DataFrame:
11+
def load_dataframe(date: str, gerk: str) -> pd.DataFrame:
1212
csv_files = list(CSV_DATA_PATH.glob("*.csv"))
1313
if not csv_files:
1414
raise FileNotFoundError(f"No CSV files found in {CSV_DATA_PATH}")
@@ -43,42 +43,196 @@ def load_dataframe() -> pd.DataFrame:
4343
]
4444
# Remove date with zero values
4545
# combined_df = combined_df[combined_df["date"] != "2023_09_20"]
46-
# combined_df = combined_df[combined_df["date"] == "2024_09_23"]
47-
# combined_df = combined_df[combined_df["gerk"] == "6006"
46+
combined_df = combined_df[combined_df["date"] == date]
47+
combined_df = combined_df[combined_df["gerk"] == gerk]
4848
return combined_df
4949

5050

51-
class CrpLoader(BaseDataLoader):
51+
def load_data_balanced(date: str, gerk: str) -> tuple[np.ndarray, np.ndarray]:
    """Load one date/gerk subset and balance it across its groups.

    Rows are fetched with ``load_dataframe(date, gerk)`` and grouped by the
    combination of their ``date``, ``gerk`` and ``st`` values.  Every group is
    down-sampled, with a fixed seed of 42, to the size of the smallest group.

    Args:
        date: Value forwarded to ``load_dataframe`` as the ``date`` filter.
        gerk: Value forwarded to ``load_dataframe`` as the ``gerk`` filter.

    Returns:
        A ``(X, y)`` tuple.  ``X`` contains every column except ``date``,
        ``GERK_ID_St``, ``gerk``, ``id`` and ``st``; ``y`` contains the
        ``st`` column of the balanced rows.

    Raises:
        ValueError: If ``load_dataframe`` returns no rows for this subset.
    """
    combined_df = load_dataframe(date, gerk)
    if combined_df.empty:
        # Fail early with a readable message instead of an opaque reduction
        # error from taking the minimum of an empty count array.
        raise ValueError(f"No rows found for date={date!r}, gerk={gerk!r}")

    # Build the grouping key as a standalone array so the (possibly sliced)
    # frame returned by load_dataframe is never mutated.
    group_key = (
        combined_df["date"].astype(str)
        + "_"
        + combined_df["gerk"].astype(str)
        + "_"
        + combined_df["st"].astype(str)
    ).to_numpy()

    # One sorted groupby pass replaces a full boolean-mask scan per group.
    grouped = combined_df.groupby(group_key, sort=True)
    min_count = int(grouped.size().min())

    # Each group is sampled independently with the same seed so the selected
    # rows do not depend on how many groups precede it.
    balanced_dfs = [
        group_df.sample(n=min_count, random_state=42) for _, group_df in grouped
    ]
    balanced_df = pd.concat(balanced_dfs, ignore_index=True)

    # Extract features and labels
    X = balanced_df.drop(columns=["date", "GERK_ID_St", "gerk", "id", "st"]).values
    y = np.array(balanced_df["st"].values)
    logger.info(
        f"Balanced dataset: {len(X)} samples from {grouped.ngroups} groups (min count per group: {min_count})"
    )
    return X, y
83+
84+
85+
class CrpLoader2023_06_26_1528147(BaseDataLoader):
    """Loader for the rows with date ``2023_06_26`` and gerk ``1528147``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_06_26", gerk="1528147")
        )
89+
90+
91+
class CrpLoader2023_06_26_174223(BaseDataLoader):
    """Loader for the rows with date ``2023_06_26`` and gerk ``174223``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_06_26", gerk="174223")
        )
95+
96+
97+
class CrpLoader2023_06_26_2119192(BaseDataLoader):
    """Loader for the rows with date ``2023_06_26`` and gerk ``2119192``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_06_26", gerk="2119192")
        )
101+
102+
103+
class CrpLoader2023_06_26_3266572(BaseDataLoader):
    """Loader for the rows with date ``2023_06_26`` and gerk ``3266572``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_06_26", gerk="3266572")
        )
107+
108+
109+
class CrpLoader2023_06_26_4606283(BaseDataLoader):
    """Loader for the rows with date ``2023_06_26`` and gerk ``4606283``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_06_26", gerk="4606283")
        )
113+
114+
115+
class CrpLoader2023_06_26_5099541(BaseDataLoader):
    """Loader for the rows with date ``2023_06_26`` and gerk ``5099541``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_06_26", gerk="5099541")
        )
119+
120+
121+
class CrpLoader2023_06_26_6006(BaseDataLoader):
    """Loader for the rows with date ``2023_06_26`` and gerk ``6006``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_06_26", gerk="6006")
        )
125+
126+
127+
class CrpLoader2023_09_20_1528147(BaseDataLoader):
    """Loader for the rows with date ``2023_09_20`` and gerk ``1528147``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_09_20", gerk="1528147")
        )
131+
132+
133+
class CrpLoader2023_09_20_174223(BaseDataLoader):
    """Loader for the rows with date ``2023_09_20`` and gerk ``174223``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_09_20", gerk="174223")
        )
137+
138+
139+
class CrpLoader2023_09_20_2119192(BaseDataLoader):
    """Loader for the rows with date ``2023_09_20`` and gerk ``2119192``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_09_20", gerk="2119192")
        )
143+
144+
145+
class CrpLoader2023_09_20_3266572(BaseDataLoader):
    """Loader for the rows with date ``2023_09_20`` and gerk ``3266572``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_09_20", gerk="3266572")
        )
149+
150+
151+
class CrpLoader2023_09_20_4606283(BaseDataLoader):
    """Loader for the rows with date ``2023_09_20`` and gerk ``4606283``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_09_20", gerk="4606283")
        )
155+
156+
157+
class CrpLoader2023_09_20_5099541(BaseDataLoader):
    """Loader for the rows with date ``2023_09_20`` and gerk ``5099541``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_09_20", gerk="5099541")
        )
161+
162+
163+
class CrpLoader2023_09_20_6006(BaseDataLoader):
    """Loader for the rows with date ``2023_09_20`` and gerk ``6006``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2023_09_20", gerk="6006")
        )
167+
168+
169+
class CrpLoader2024_07_10_1528147(BaseDataLoader):
    """Loader for the rows with date ``2024_07_10`` and gerk ``1528147``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_07_10", gerk="1528147")
        )
173+
174+
175+
class CrpLoader2024_07_10_2119192(BaseDataLoader):
    """Loader for the rows with date ``2024_07_10`` and gerk ``2119192``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_07_10", gerk="2119192")
        )
179+
180+
181+
class CrpLoader2024_07_10_4606283(BaseDataLoader):
    """Loader for the rows with date ``2024_07_10`` and gerk ``4606283``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_07_10", gerk="4606283")
        )
185+
186+
187+
class CrpLoader2024_07_10_5099541(BaseDataLoader):
    """Loader for the rows with date ``2024_07_10`` and gerk ``5099541``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_07_10", gerk="5099541")
        )
191+
192+
193+
class CrpLoader2024_07_10_6006(BaseDataLoader):
    """Loader for the rows with date ``2024_07_10`` and gerk ``6006``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_07_10", gerk="6006")
        )
197+
198+
199+
class CrpLoader2024_09_23_1528147(BaseDataLoader):
    """Loader for the rows with date ``2024_09_23`` and gerk ``1528147``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_09_23", gerk="1528147")
        )
203+
204+
205+
class CrpLoader2024_09_23_174223(BaseDataLoader):
    """Loader for the rows with date ``2024_09_23`` and gerk ``174223``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_09_23", gerk="174223")
        )
209+
210+
211+
class CrpLoader2024_09_23_2119192(BaseDataLoader):
    """Loader for the rows with date ``2024_09_23`` and gerk ``2119192``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_09_23", gerk="2119192")
        )
215+
216+
217+
class CrpLoader2024_09_23_3266572(BaseDataLoader):
    """Loader for the rows with date ``2024_09_23`` and gerk ``3266572``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_09_23", gerk="3266572")
        )
221+
222+
223+
class CrpLoader2024_09_23_4606283(BaseDataLoader):
    """Loader for the rows with date ``2024_09_23`` and gerk ``4606283``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_09_23", gerk="4606283")
        )
227+
228+
229+
class CrpLoader2024_09_23_5099541(BaseDataLoader):
    """Loader for the rows with date ``2024_09_23`` and gerk ``5099541``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_09_23", gerk="5099541")
        )
233+
234+
235+
class CrpLoader2024_09_23_6006(BaseDataLoader):
    """Loader for the rows with date ``2024_09_23`` and gerk ``6006``."""

    def load_data(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the balanced subset after passing it through ``_shuffle_data``."""
        return self._shuffle_data(
            *load_data_balanced(date="2024_09_23", gerk="6006")
        )

0 commit comments

Comments
 (0)