147 changes: 147 additions & 0 deletions src/deepforest/callbacks.py
@@ -6,6 +6,9 @@

import json
import os
import shutil
import sys
import tempfile
import warnings
from pathlib import Path

@@ -61,6 +64,150 @@ def on_train_start(self, trainer, pl_module):
self.trainer = trainer
self.pl_module = pl_module

# --- Log up to 5 annotated training images (ground truth) before training starts ---
# This helps verify that the annotations align with the images. We render
# up to 5 non-empty annotated images to a temporary directory and then
# log them using the existing logger handling in _log_to_all().
try:
train_ds = trainer.train_dataloader.dataset

# Collect image names that actually have annotations
image_names = list(
getattr(train_ds, "image_names", train_ds.annotations.image_path.unique())
)
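# Prefer the dataset's image_names attribute; otherwise fall back to the unique
# image_path values in the annotations frame. Note that getattr evaluates the
# fallback expression eagerly, which is fine here since any failure is caught
# by the surrounding try/except.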
non_empty = []
for img_name in image_names:
try:
targets = train_ds.annotations_for_path(img_name)
except Exception:
# If annotations_for_path fails for any image, skip it
continue
if (
targets
and "boxes" in targets
and getattr(targets["boxes"], "shape", (0,))[0] > 0
):
non_empty.append(img_name)

if len(non_empty) > 0:
n = min(5, len(non_empty))
selected = np.random.choice(non_empty, size=n, replace=False)

tmpdir = tempfile.mkdtemp()
try:
for filename in selected:
# Subset annotations for the chosen image and ensure root_dir is set
sample_ann = train_ds.annotations[
train_ds.annotations.image_path == filename
].copy()
sample_ann.root_dir = train_ds.root_dir

# Plot the annotated image and save it to the temporary directory
basename = Path(filename).stem
fig = visualize.plot_annotations(
annotations=sample_ann,
savedir=tmpdir,
basename=basename,
show=False,
)
plt.close(fig)

# Log image to available loggers (Comet/TensorBoard/etc.)
self._log_to_all(
image=os.path.join(tmpdir, basename + ".png"),
trainer=trainer,
tag="train_annotated_sample",
)
finally:
# Clean up temporary directory
try:
shutil.rmtree(tmpdir)
except Exception:
pass
except Exception as e:
# Don't fail training startup on logging issues; warn instead
warnings.warn(f"Could not log annotated training samples: {e}", stacklevel=2)
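# For reference: for an image file such as "plot_1.tif" (hypothetical name),
# the rendered figure is expected at <tmpdir>/plot_1.png, i.e.
# Path(filename).stem + ".png" inside the temporary directory, which is the
# path handed to _log_to_all above.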

# Interactive image screen configuration (optional)
# Triggered only when the environment variable is set and stdin is a TTY,
# so automated runs are unaffected.
try:
if os.getenv("DEEPFOREST_IMAGE_SCREEN_PROMPT") == "1" and sys.stdin.isatty():
print(
"\nDeepForest image screen options:\n1) Default\n2) Arduino\n3) Custom"
)
choice = input("Select option [1-3] (enter to skip): ").strip()
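# Only option 3 (Custom) prompts for further configuration; options 1 and 2
# (and an empty answer) fall through without changing any settings.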
if choice == "3":
# Supported sizes and drivers (a small set of common defaults)
sizes = ["128x64", "240x135", "320x240", "480x320"]
drivers = ["SSD1306", "ST7735", "ILI9341", "Custom"]

print("\nSupported custom sizes:")
for i, s in enumerate(sizes, start=1):
print(f"{i}) {s}")
s_choice = input("Choose size [1-4] or enter custom WxH: ").strip()
if s_choice.isdigit() and 1 <= int(s_choice) <= len(sizes):
size_str = sizes[int(s_choice) - 1]
else:
size_str = s_choice

print("\nSupported drivers:")
for i, d in enumerate(drivers, start=1):
print(f"{i}) {d}")
d_choice = input("Choose driver [1-4] or enter custom name: ").strip()
if d_choice.isdigit() and 1 <= int(d_choice) <= len(drivers):
driver = drivers[int(d_choice) - 1]
else:
driver = d_choice

color_input = input(
"Enter color (named e.g. red, or hex #RRGGBB) or press enter for default: "
).strip()

# Normalize color to an RGB list if provided
def _parse_color(c: str):
if not c:
return None
c = c.strip()
# hex
if c.startswith("#") and len(c) == 7:
try:
return [int(c[i : i + 2], 16) for i in (1, 3, 5)]
except Exception:
return None
# simple names
name_map = {
"red": [255, 0, 0],
"green": [0, 255, 0],
"blue": [0, 0, 255],
"yellow": [255, 255, 0],
"orange": [255, 165, 0],
}
return name_map.get(c.lower())

color_rgb = _parse_color(color_input)
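# Worked examples (derived from _parse_color above):
#   "#ff8800" -> [255, 136, 0], "red" -> [255, 0, 0];
#   malformed hex or an unmapped name (e.g. "teal") -> None, meaning the
#   default color is kept.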

# Store configuration on the callback for later use
self.custom_screen = {
"size": tuple(int(x) for x in size_str.split("x"))
if "x" in size_str
else None,
"driver": driver,
"color": color_rgb,
}
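# Example (illustrative): selecting 240x135, ST7735 and "red" yields
# {"size": (240, 135), "driver": "ST7735", "color": [255, 0, 0]}.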

# If a color was selected, apply it to the callback so plotting
# uses the chosen color for annotations.
if color_rgb is not None:
self.color = color_rgb

print(f"Custom screen configured: {self.custom_screen}\n")
except Exception as e:
# Never fail training startup for interactive configuration issues
warnings.warn(
f"Could not run interactive image screen config: {e}", stacklevel=2
)
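# Usage sketch (illustrative script name): the prompt appears only when the
# environment variable is set before launching training from an interactive
# terminal, e.g.
#   DEEPFOREST_IMAGE_SCREEN_PROMPT=1 python my_train_script.py
# CI jobs and piped runs have a non-TTY stdin, so they skip the prompt entirely.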

# Training samples
pl_module.print("Logging training dataset samples.")
train_ds = trainer.train_dataloader.dataset