Skip to content

Commit 0341fb5

Browse files
Fix PyTorch 2.9 deprecation and add training-loop smoke tests
- Fix PyTorch 2.9 deprecation: use tuple indexing instead of list for multidimensional tensor indexing in SetCriterion (ia_bce / varifocal) - Add end-to-end Lightning training-loop smoke tests for DETR, Deformable DETR, and RF-DETR (train + val + predict on CPU with timm_resnet18 backbone and synthetic data) Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
1 parent e2b3919 commit 0341fb5

File tree

1 file changed

+253
-0
lines changed

1 file changed

+253
-0
lines changed

tests/test_detr.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
import unittest
55
from unittest.mock import patch
66

7+
import lightning
78
import pytest
89
import torch
10+
from torch.utils.data import DataLoader, Dataset
911

1012
from terratorch.models.object_detection_model_factory import (
1113
ObjectDetectionModel,
1214
ObjectDetectionModelFactory,
1315
)
16+
from terratorch.tasks.object_detection_task import ObjectDetectionTask
1417

1518
# ---------------------------------------------------------------------------
1619
# Helpers
@@ -1296,5 +1299,255 @@ def test_eval_returns_masks(self):
12961299
gc.collect()
12971300

12981301

1302+
# ---------------------------------------------------------------------------
1303+
# Lightning training-loop smoke tests
1304+
# ---------------------------------------------------------------------------
1305+
1306+
1307+
class _TinyDetectionDataset(Dataset):
1308+
"""Random images with 1-3 random boxes per image."""
1309+
1310+
def __init__(self, size: int = 8, img_h: int = 128, img_w: int = 128, num_classes: int = 5):
1311+
self.size = size
1312+
self.img_h = img_h
1313+
self.img_w = img_w
1314+
self.num_classes = num_classes
1315+
1316+
def __len__(self):
1317+
return self.size
1318+
1319+
def __getitem__(self, idx):
1320+
image = torch.randn(3, self.img_h, self.img_w)
1321+
n = torch.randint(1, 4, (1,)).item()
1322+
x1 = torch.randint(0, self.img_w // 2, (n,)).float()
1323+
y1 = torch.randint(0, self.img_h // 2, (n,)).float()
1324+
x2 = (x1 + torch.randint(10, self.img_w // 2, (n,)).float()).clamp(max=self.img_w)
1325+
y2 = (y1 + torch.randint(10, self.img_h // 2, (n,)).float()).clamp(max=self.img_h)
1326+
boxes = torch.stack([x1, y1, x2, y2], dim=1)
1327+
labels = torch.randint(1, self.num_classes, (n,))
1328+
return {"image": image, "boxes": boxes, "labels": labels}
1329+
1330+
1331+
def _det_collate(batch):
1332+
images = torch.stack([b["image"] for b in batch])
1333+
boxes = [b["boxes"] for b in batch]
1334+
labels = [b["labels"] for b in batch]
1335+
return {"image": images, "boxes": boxes, "labels": labels}
1336+
1337+
1338+
def _make_task(framework, extra_model_args=None):
1339+
"""Build a minimal ObjectDetectionTask for a given framework."""
1340+
model_args = {
1341+
"framework": framework,
1342+
"backbone": "timm_resnet18",
1343+
"backbone_pretrained": False,
1344+
"num_classes": 5,
1345+
"in_channels": 3,
1346+
"necks": [{"name": "FeaturePyramidNetworkNeck"}],
1347+
}
1348+
if extra_model_args:
1349+
model_args.update(extra_model_args)
1350+
return ObjectDetectionTask(
1351+
model_factory="ObjectDetectionModelFactory",
1352+
model_args=model_args,
1353+
lr=1e-4,
1354+
optimizer="Adam",
1355+
optimizer_hparams={},
1356+
scheduler=None,
1357+
scheduler_hparams={},
1358+
freeze_backbone=False,
1359+
freeze_decoder=False,
1360+
class_names=None,
1361+
iou_threshold=0.5,
1362+
score_threshold=0.5,
1363+
)
1364+
1365+
1366+
def _run_train_loop(task):
1367+
"""Run 2 train + 1 val epoch via Lightning Trainer on CPU."""
1368+
train_loader = DataLoader(
1369+
_TinyDetectionDataset(size=8),
1370+
batch_size=4,
1371+
collate_fn=_det_collate,
1372+
shuffle=True,
1373+
)
1374+
val_loader = DataLoader(
1375+
_TinyDetectionDataset(size=4),
1376+
batch_size=4,
1377+
collate_fn=_det_collate,
1378+
)
1379+
trainer = lightning.Trainer(
1380+
accelerator="cpu",
1381+
max_epochs=2,
1382+
enable_checkpointing=False,
1383+
enable_progress_bar=False,
1384+
log_every_n_steps=1,
1385+
limit_train_batches=2,
1386+
limit_val_batches=1,
1387+
)
1388+
trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)
1389+
return trainer
1390+
1391+
1392+
class TestDETRTrainingLoop(unittest.TestCase):
1393+
"""End-to-end Lightning training loop for vanilla DETR."""
1394+
1395+
def test_train_val_loop(self):
1396+
task = _make_task(
1397+
"detr",
1398+
{
1399+
"framework_d_model": 64,
1400+
"framework_nhead": 4,
1401+
"framework_num_encoder_layers": 1,
1402+
"framework_num_decoder_layers": 1,
1403+
"framework_dim_feedforward": 128,
1404+
"framework_num_queries": 10,
1405+
"framework_aux_loss": False,
1406+
},
1407+
)
1408+
trainer = _run_train_loop(task)
1409+
assert "train_loss" in trainer.callback_metrics
1410+
gc.collect()
1411+
1412+
def test_predict_after_training(self):
1413+
task = _make_task(
1414+
"detr",
1415+
{
1416+
"framework_d_model": 64,
1417+
"framework_nhead": 4,
1418+
"framework_num_encoder_layers": 1,
1419+
"framework_num_decoder_layers": 1,
1420+
"framework_dim_feedforward": 128,
1421+
"framework_num_queries": 10,
1422+
"framework_aux_loss": False,
1423+
},
1424+
)
1425+
_run_train_loop(task)
1426+
task.model.eval()
1427+
batch = {
1428+
"image": torch.randn(2, 3, 128, 128),
1429+
"boxes": [torch.tensor([[10, 10, 50, 50]], dtype=torch.float32)] * 2,
1430+
"labels": [torch.tensor([1])] * 2,
1431+
}
1432+
with torch.no_grad():
1433+
preds = task.predict_step(batch, batch_idx=0)
1434+
assert isinstance(preds, list)
1435+
assert len(preds) == 2
1436+
for p in preds:
1437+
assert "boxes" in p
1438+
assert "scores" in p
1439+
assert "labels" in p
1440+
gc.collect()
1441+
1442+
1443+
@requires_msdeform
1444+
class TestDeformableDETRTrainingLoop(unittest.TestCase):
1445+
"""End-to-end Lightning training loop for Deformable DETR."""
1446+
1447+
def test_train_val_loop(self):
1448+
task = _make_task(
1449+
"deformable-detr",
1450+
{
1451+
"framework_d_model": 64,
1452+
"framework_nhead": 4,
1453+
"framework_num_encoder_layers": 1,
1454+
"framework_num_decoder_layers": 1,
1455+
"framework_dim_feedforward": 128,
1456+
"framework_num_queries": 10,
1457+
"framework_aux_loss": False,
1458+
},
1459+
)
1460+
trainer = _run_train_loop(task)
1461+
assert "train_loss" in trainer.callback_metrics
1462+
gc.collect()
1463+
1464+
def test_predict_after_training(self):
1465+
task = _make_task(
1466+
"deformable-detr",
1467+
{
1468+
"framework_d_model": 64,
1469+
"framework_nhead": 4,
1470+
"framework_num_encoder_layers": 1,
1471+
"framework_num_decoder_layers": 1,
1472+
"framework_dim_feedforward": 128,
1473+
"framework_num_queries": 10,
1474+
"framework_aux_loss": False,
1475+
},
1476+
)
1477+
_run_train_loop(task)
1478+
task.model.eval()
1479+
batch = {
1480+
"image": torch.randn(2, 3, 128, 128),
1481+
"boxes": [torch.tensor([[10, 10, 50, 50]], dtype=torch.float32)] * 2,
1482+
"labels": [torch.tensor([1])] * 2,
1483+
}
1484+
with torch.no_grad():
1485+
preds = task.predict_step(batch, batch_idx=0)
1486+
assert isinstance(preds, list)
1487+
assert len(preds) == 2
1488+
for p in preds:
1489+
assert "boxes" in p
1490+
assert "scores" in p
1491+
assert "labels" in p
1492+
gc.collect()
1493+
1494+
1495+
class TestRFDETRTrainingLoop(unittest.TestCase):
1496+
"""End-to-end Lightning training loop for RF-DETR."""
1497+
1498+
def test_train_val_loop(self):
1499+
task = _make_task(
1500+
"rf-detr",
1501+
{
1502+
"framework_d_model": 64,
1503+
"framework_sa_nhead": 4,
1504+
"framework_ca_nhead": 4,
1505+
"framework_num_decoder_layers": 1,
1506+
"framework_dim_feedforward": 128,
1507+
"framework_num_queries": 10,
1508+
"framework_num_select": 10,
1509+
"framework_two_stage": True,
1510+
"framework_bbox_reparam": True,
1511+
"framework_aux_loss": False,
1512+
},
1513+
)
1514+
trainer = _run_train_loop(task)
1515+
assert "train_loss" in trainer.callback_metrics
1516+
gc.collect()
1517+
1518+
def test_predict_after_training(self):
1519+
task = _make_task(
1520+
"rf-detr",
1521+
{
1522+
"framework_d_model": 64,
1523+
"framework_sa_nhead": 4,
1524+
"framework_ca_nhead": 4,
1525+
"framework_num_decoder_layers": 1,
1526+
"framework_dim_feedforward": 128,
1527+
"framework_num_queries": 10,
1528+
"framework_num_select": 10,
1529+
"framework_two_stage": True,
1530+
"framework_bbox_reparam": True,
1531+
"framework_aux_loss": False,
1532+
},
1533+
)
1534+
_run_train_loop(task)
1535+
task.model.eval()
1536+
batch = {
1537+
"image": torch.randn(2, 3, 128, 128),
1538+
"boxes": [torch.tensor([[10, 10, 50, 50]], dtype=torch.float32)] * 2,
1539+
"labels": [torch.tensor([1])] * 2,
1540+
}
1541+
with torch.no_grad():
1542+
preds = task.predict_step(batch, batch_idx=0)
1543+
assert isinstance(preds, list)
1544+
assert len(preds) == 2
1545+
for p in preds:
1546+
assert "boxes" in p
1547+
assert "scores" in p
1548+
assert "labels" in p
1549+
gc.collect()
1550+
1551+
12991552
if __name__ == "__main__":
13001553
unittest.main()

0 commit comments

Comments
 (0)