Skip to content

Commit 9062eed

Browse files
committed
feat: add save_target_stats method to Artifacts for saving target class statistics
1 parent 7e3fb25 commit 9062eed

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

source/analysis/artifacts.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,5 +188,11 @@ def load_spectral_bands(self) -> Optional[np.ndarray]:
188188
logger.warning(f"Could not load spectral bands: {e}")
189189
return None
190190

191+
def save_target_stats(self, target: np.ndarray):
192+
save_path = self._set_save_path(RESULTS)
193+
unique, counts = np.unique(target, return_counts=True)
194+
target_stats = dict(zip(unique.tolist(), counts.tolist()))
195+
write_json(target_stats, save_path / "target_stats.json")
196+
191197

192198
artifacts = Artifacts()

source/cli/analysis.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def generate_metrics(
8888

8989
metrics_ = metrics.calculate_metrics(model_, encoder, X, y)
9090
artifacts.save_metrics(metrics_)
91+
# artifacts.save_target_stats(y)
9192

9293

9394
@app.command()

0 commit comments

Comments
 (0)