 import unittest
 from unittest.mock import patch
 
+import lightning
 import pytest
 import torch
+from torch.utils.data import DataLoader, Dataset
 
 from terratorch.models.object_detection_model_factory import (
     ObjectDetectionModel,
     ObjectDetectionModelFactory,
 )
+from terratorch.tasks.object_detection_task import ObjectDetectionTask
 
 # ---------------------------------------------------------------------------
 # Helpers
@@ -1296,5 +1299,262 @@ def test_eval_returns_masks(self):
         gc.collect()
 
 
+# ---------------------------------------------------------------------------
+# Lightning training-loop smoke tests
+# ---------------------------------------------------------------------------
+
+
+class _TinyDetectionDataset(Dataset):
+    """Random images with 1-3 random boxes per image."""
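+    # Targets are torchvision-style dicts: float xyxy boxes and integer labels in [1, num_classes).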
+
+    def __init__(self, size: int = 8, img_h: int = 128, img_w: int = 128, num_classes: int = 5):
+        self.size = size
+        self.img_h = img_h
+        self.img_w = img_w
+        self.num_classes = num_classes
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        image = torch.randn(3, self.img_h, self.img_w)
+        n = torch.randint(1, 4, (1,)).item()
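+        # Corners start in the top-left quadrant and widths/heights are >= 10 px, so boxes stay well-formed after clamping.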
+        x1 = torch.randint(0, self.img_w // 2, (n,)).float()
+        y1 = torch.randint(0, self.img_h // 2, (n,)).float()
+        x2 = (x1 + torch.randint(10, self.img_w // 2, (n,)).float()).clamp(max=self.img_w)
+        y2 = (y1 + torch.randint(10, self.img_h // 2, (n,)).float()).clamp(max=self.img_h)
+        boxes = torch.stack([x1, y1, x2, y2], dim=1)
+        labels = torch.randint(1, self.num_classes, (n,))
+        return {"image": image, "boxes": boxes, "labels": labels}
+
+
+def _det_collate(batch):
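+    # Fixed-size images stack into one tensor; boxes/labels stay per-image lists because counts vary.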
+    images = torch.stack([b["image"] for b in batch])
+    boxes = [b["boxes"] for b in batch]
+    labels = [b["labels"] for b in batch]
+    return {"image": images, "boxes": boxes, "labels": labels}
+
+
+def _make_task(framework, extra_model_args=None):
+    """Build a minimal ObjectDetectionTask for a given framework."""
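+    # A small ResNet-18 backbone with an FPN neck keeps the factory path realistic but CPU-fast.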
+    model_args = {
+        "framework": framework,
+        "backbone": "timm_resnet18",
+        "backbone_pretrained": False,
+        "num_classes": 5,
+        "in_channels": 3,
+        "necks": [{"name": "FeaturePyramidNetworkNeck"}],
+    }
+    if extra_model_args:
+        model_args.update(extra_model_args)
+    return ObjectDetectionTask(
+        model_factory="ObjectDetectionModelFactory",
+        model_args=model_args,
+        lr=1e-4,
+        optimizer="Adam",
+        optimizer_hparams={},
+        scheduler=None,
+        scheduler_hparams={},
+        freeze_backbone=False,
+        freeze_decoder=False,
+        class_names=None,
+        iou_threshold=0.5,
+        score_threshold=0.5,
+    )
+
+
+def _run_train_loop(task):
+    """Run a short Lightning fit on CPU: 2 epochs, capped at 2 train batches and 1 val batch each."""
+    train_loader = DataLoader(
+        _TinyDetectionDataset(size=8),
+        batch_size=4,
+        collate_fn=_det_collate,
+        shuffle=True,
+    )
+    val_loader = DataLoader(
+        _TinyDetectionDataset(size=4),
+        batch_size=4,
+        collate_fn=_det_collate,
+    )
+    trainer = lightning.Trainer(
+        accelerator="cpu",
+        max_epochs=2,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        log_every_n_steps=1,
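+        # Cap batch counts so each epoch takes only seconds on CPU.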
+        limit_train_batches=2,
+        limit_val_batches=1,
+    )
+    trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)
+    return trainer
+
+
+class TestDETRTrainingLoop(unittest.TestCase):
+    """End-to-end Lightning training loop for vanilla DETR."""
+
+    def test_train_val_loop(self):
+        task = _make_task(
+            "detr",
+            {
+                "framework_d_model": 64,
+                "framework_nhead": 4,
+                "framework_num_encoder_layers": 1,
+                "framework_num_decoder_layers": 1,
+                "framework_dim_feedforward": 128,
+                "framework_num_queries": 10,
+                "framework_aux_loss": False,
+            },
+        )
+        trainer = _run_train_loop(task)
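+        # Lightning surfaces the most recently logged values via trainer.callback_metrics.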
+        assert "train_loss" in trainer.callback_metrics
+        gc.collect()
+
+    def test_predict_after_training(self):
+        task = _make_task(
+            "detr",
+            {
+                "framework_d_model": 64,
+                "framework_nhead": 4,
+                "framework_num_encoder_layers": 1,
+                "framework_num_decoder_layers": 1,
+                "framework_dim_feedforward": 128,
+                "framework_num_queries": 10,
+                "framework_aux_loss": False,
+            },
+        )
+        _run_train_loop(task)
+        task.model.eval()
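+        # Hand-built two-image batch in the same layout _det_collate produces.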
+        batch = {
+            "image": torch.randn(2, 3, 128, 128),
+            "boxes": [torch.tensor([[10, 10, 50, 50]], dtype=torch.float32)] * 2,
+            "labels": [torch.tensor([1])] * 2,
+        }
+        with torch.no_grad():
+            preds = task.predict_step(batch, batch_idx=0)
+        assert isinstance(preds, list)
+        assert len(preds) == 2
+        for p in preds:
+            assert "boxes" in p
+            assert "scores" in p
+            assert "labels" in p
+        gc.collect()
+
+
+@requires_msdeform
+class TestDeformableDETRTrainingLoop(unittest.TestCase):
+    """End-to-end Lightning training loop for Deformable DETR."""
+
+    def test_train_val_loop(self):
+        task = _make_task(
+            "deformable-detr",
+            {
+                "framework_d_model": 64,
+                "framework_nhead": 4,
+                "framework_num_encoder_layers": 1,
+                "framework_num_decoder_layers": 1,
+                "framework_dim_feedforward": 128,
+                "framework_num_queries": 10,
+                "framework_aux_loss": False,
+            },
+        )
+        trainer = _run_train_loop(task)
+        assert "train_loss" in trainer.callback_metrics
+        gc.collect()
+
+    def test_predict_after_training(self):
+        task = _make_task(
+            "deformable-detr",
+            {
+                "framework_d_model": 64,
+                "framework_nhead": 4,
+                "framework_num_encoder_layers": 1,
+                "framework_num_decoder_layers": 1,
+                "framework_dim_feedforward": 128,
+                "framework_num_queries": 10,
+                "framework_aux_loss": False,
+            },
+        )
+        _run_train_loop(task)
+        task.model.eval()
+        batch = {
+            "image": torch.randn(2, 3, 128, 128),
+            "boxes": [torch.tensor([[10, 10, 50, 50]], dtype=torch.float32)] * 2,
+            "labels": [torch.tensor([1])] * 2,
+        }
+        with torch.no_grad():
+            preds = task.predict_step(batch, batch_idx=0)
+        assert isinstance(preds, list)
+        assert len(preds) == 2
+        for p in preds:
+            assert "boxes" in p
+            assert "scores" in p
+            assert "labels" in p
+        gc.collect()
+
+
+class TestRFDETRTrainingLoop(unittest.TestCase):
+    """End-to-end Lightning training loop for RF-DETR."""
+
+    def test_train_val_loop(self):
+        task = _make_task(
+            "rf-detr",
+            {
+                "framework_d_model": 64,
+                "framework_sa_nhead": 4,
+                "framework_ca_nhead": 4,
+                "framework_num_decoder_layers": 1,
+                "framework_dim_feedforward": 128,
+                "framework_num_queries": 10,
+                "framework_num_select": 10,
+                "framework_two_stage": True,
+                "framework_bbox_reparam": True,
+                "framework_aux_loss": False,
+            },
+        )
+        trainer = _run_train_loop(task)
+        assert "train_loss" in trainer.callback_metrics
+        gc.collect()
+
+    def test_predict_after_training(self):
+        task = _make_task(
+            "rf-detr",
+            {
+                "framework_d_model": 64,
+                "framework_sa_nhead": 4,
+                "framework_ca_nhead": 4,
+                "framework_num_decoder_layers": 1,
+                "framework_dim_feedforward": 128,
+                "framework_num_queries": 10,
+                "framework_num_select": 10,
+                "framework_two_stage": True,
+                "framework_bbox_reparam": True,
+                "framework_aux_loss": False,
+            },
+        )
+        _run_train_loop(task)
+        task.model.eval()
+        batch = {
+            "image": torch.randn(2, 3, 128, 128),
+            "boxes": [torch.tensor([[10, 10, 50, 50]], dtype=torch.float32)] * 2,
+            "labels": [torch.tensor([1])] * 2,
+        }
+        with torch.no_grad():
+            preds = task.predict_step(batch, batch_idx=0)
+        assert isinstance(preds, list)
+        assert len(preds) == 2
+        for p in preds:
+            assert "boxes" in p
+            assert "scores" in p
+            assert "labels" in p
+        gc.collect()
+
+
 if __name__ == "__main__":
     unittest.main()