Add DETR, Deformable DETR, and RF-DETR Object Detection / Instance Segmentation Network Architectures #1113

Open
josiahwsmith10 wants to merge 10 commits into terrastackai:main from josiahwsmith10:add_DETR_models

Conversation

@josiahwsmith10

Closes #1111

Copy verbatim reference implementations for provenance tracking:
- detr.py, transformer.py, matcher.py, position_encoding.py from facebook/detr
- deformable_detr.py, deformable_transformer.py from fundamentalvision/Deformable-DETR

Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
Modify the verbatim reference files from facebook/detr and
fundamentalvision/Deformable-DETR for use with TerraTorch's
ObjectDetectionModelFactory:

- Remove external dependencies (NestedTensor, util.box_ops,
  util.misc, segmentation, backbone modules) and replace with
  torchvision.ops equivalents
- Add terratorch_detr.py with TerraTorchDETR and
  TerraTorchDeformableDETR wrapper classes that compose reference
  components with BackboneWrapper
- Update __init__.py to re-export wrapper classes as DETR and
  DeformableDETR
- Integrate with ObjectDetectionModelFactory ('detr' and
  'deformable-detr' framework options)
- Add tests and example YAML config
- Ruff lint and format compliance

Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
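
For orientation, a minimal usage sketch of the factory integration described in the commit above. The 'detr' and 'deformable-detr' framework strings come from this PR; the import path, build_model signature, and kwarg names below are assumptions for illustration only.

    from terratorch.models import ObjectDetectionModelFactory  # import path assumed

    factory = ObjectDetectionModelFactory()
    # Hypothetical kwargs: the backbone name and num_classes are placeholders.
    model = factory.build_model(
        task="object_detection",
        backbone="timm_resnet18",
        framework="detr",  # or "deformable-detr"
        num_classes=5,
    )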
… DETR and Deformable DETR

Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
Copy verbatim model files from https://github.com/roboflow/rf-detr
(commit 82ad5d1) under Apache-2.0 License for adaptation into
TerraTorch's ObjectDetectionModelFactory.

Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
Verbatim copy of facebook/detr models/segmentation.py containing mask
prediction heads (MHAttentionMap, MaskHeadSmallConv), losses (dice_loss,
sigmoid_focal_loss), and post-processing (PostProcessSegm, PostProcessPanoptic).
Will be adapted for TerraTorch in a subsequent commit.

Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
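
For context, the dice loss retained here is the standard soft-Dice formulation over flattened per-query masks. A condensed sketch of that formulation (the vendored file itself is a verbatim copy of facebook/detr):

    import torch

    def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, num_boxes: float) -> torch.Tensor:
        # inputs: raw mask logits; targets: binary ground-truth masks, both (N, H, W).
        inputs = inputs.sigmoid().flatten(1)
        targets = targets.flatten(1)
        numerator = 2 * (inputs * targets).sum(1)
        denominator = inputs.sum(-1) + targets.sum(-1)
        # The +1 smoothing keeps the ratio finite when a mask is empty.
        loss = 1 - (numerator + 1) / (denominator + 1)
        return loss.sum() / num_boxes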
Adapt reference segmentation.py for TerraTorch (remove DETRsegm wrapper,
PostProcessPanoptic, NestedTensor; keep MHAttentionMap, MaskHeadSmallConv,
dice_loss, sigmoid_focal_loss, PostProcessSegm). Add loss_masks to both
SetCriterion classes with masks skipped in aux loss loop. Integrate mask
head into TerraTorchDETR and TerraTorchDeformableDETR via masks=True
parameter. Update ObjectDetectionTask for DETR segmentation metrics/masks.
Add component, criterion, and end-to-end segmentation tests.

Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
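
A hedged end-to-end sketch of the masks=True path described above. The constructor kwargs are hypothetical stand-ins; the real signature lives in terratorch_detr.py.

    import torch
    from terratorch.models.detr import DETR  # re-exported TerraTorchDETR

    # Illustrative kwargs only, not the confirmed wrapper signature.
    model = DETR(backbone="timm_resnet18", num_classes=5, masks=True)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 224, 224))
    # With the mask head enabled, outputs should include "pred_masks"
    # alongside "pred_logits" and "pred_boxes".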
Integrate RF-DETR (Roboflow Detection Transformer) into TerraTorch's
ObjectDetectionModelFactory alongside existing DETR and Deformable DETR.

- Add TerraTorchRFDETR wrapper with decoder-only LW-DETR transformer,
  two-stage encoder proposals, bbox reparameterization, and pure-PyTorch
  multi-scale deformable attention (no CUDA extension required)
- Register 'rf-detr' framework in ObjectDetectionModelFactory
- Remove vendored DINOv2 backbone (replaced by TerraTorch BackboneWrapper)
- Adapt all RF-DETR components for ruff compliance (F→f_nn, assert→raise,
  Optional→union syntax, naming conventions, etc.)
- Add comprehensive parity tests proving numerical identity with reference:
  gen_encoder_output_proposals, full LWDETR forward (no two-stage,
  two-stage+bbox_reparam, two-stage without), SetCriterion with ia_bce_loss
  and varifocal_loss including aux/enc outputs, position encoding, matcher,
  PostProcess, and SegmentationHead
- Add functional tests for factory build, train/eval forward, aux/enc loss,
  and individual component validation

Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
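
The pure-PyTorch multi-scale deformable attention mentioned above is the well-known grid_sample fallback that ships with Deformable DETR and RF-DETR (ms_deform_attn_core_pytorch). A condensed sketch of that computation, with variables renamed for readability:

    import torch
    import torch.nn.functional as F

    def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
        # value: (N, S, M, D) features flattened over S = sum(H*W) across levels;
        # sampling_locations: (N, Lq, M, L, P, 2) in [0, 1];
        # attention_weights: (N, Lq, M, L, P); value_spatial_shapes: L pairs of (H, W).
        n, _, m, d = value.shape
        _, lq, _, levels, points, _ = sampling_locations.shape
        value_list = value.split([h * w for h, w in value_spatial_shapes], dim=1)
        grids = 2 * sampling_locations - 1  # grid_sample expects coordinates in [-1, 1]
        sampled = []
        for lid, (h, w) in enumerate(value_spatial_shapes):
            # (N, H*W, M, D) -> (N*M, D, H, W)
            value_l = value_list[lid].flatten(2).transpose(1, 2).reshape(n * m, d, h, w)
            # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
            grid_l = grids[:, :, :, lid].transpose(1, 2).flatten(0, 1)
            sampled.append(
                F.grid_sample(value_l, grid_l, mode="bilinear", padding_mode="zeros", align_corners=False)
            )
        # (N, Lq, M, L*P) -> (N*M, 1, Lq, L*P)
        attention_weights = attention_weights.transpose(1, 2).reshape(n * m, 1, lq, levels * points)
        # Stack levels to (N*M, D, Lq, L, P), then take the weighted sum over all L*P samples.
        output = (torch.stack(sampled, dim=-2).flatten(-2) * attention_weights).sum(-1)
        return output.view(n, m * d, lq).transpose(1, 2).contiguous()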
…andards

Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
…ultidimensional tensor indexing in SetCriterion (ia_bce / varifocal) in lwdetr.py

Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
- Fix PyTorch 2.9 deprecation: use tuple indexing instead of list for
  multidimensional tensor indexing in SetCriterion (ia_bce / varifocal)
- Add end-to-end Lightning training-loop smoke tests for DETR,
  Deformable DETR, and RF-DETR (train + val + predict on CPU with
  timm_resnet18 backbone and synthetic data)

Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
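
The PyTorch 2.9 deprecation fix above boils down to this pattern: indexing a tensor with a Python list of index tensors is deprecated, while the equivalent tuple form is not.

    import torch

    logits = torch.randn(2, 10, 5)  # (batch, queries, classes)
    batch_idx = torch.tensor([0, 1])
    query_idx = torch.tensor([3, 7])

    # Deprecated: a list of index tensors is implicitly treated as a tuple.
    # selected = logits[[batch_idx, query_idx]]
    # Preferred: an explicit tuple (same as logits[batch_idx, query_idx]).
    selected = logits[(batch_idx, query_idx)]  # shape: (2, 5)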
@josiahwsmith10
Author

Here's a test script that verifies parity with the original DETR, Deformable DETR, and RF-DETR implementations.

"""Reference parity tests: prove numerical identity with original DETR code.

These tests verify our adapted DETR, Deformable DETR, and RF-DETR code
produces identical outputs to the original reference implementations.

Tests skip gracefully when reference code is absent.  To set up the
reference repos, clone them into /private/tmp/ (or update the path
constants below):

    # DETR (Facebook Research)
    git clone https://github.com/facebookresearch/detr.git /private/tmp/detr-ref

    # Deformable DETR (Fundamentalvision)
    git clone https://github.com/fundamentalvision/Deformable-DETR.git /private/tmp/deformable-detr-ref

    # RF-DETR (Roboflow) — only the src/ subtree is needed
    git clone https://github.com/roboflow/rf-detr.git /private/tmp/rf-detr-ref
    # reference path points to /private/tmp/rf-detr-ref/src

Note: Deformable DETR tests also require the CUDA MSDeformAttn extension to
be compiled and importable (will not run on CPU-only machines).
"""

from __future__ import annotations

import sys
import types
from pathlib import Path

import pytest
import torch
from torch import nn

from terratorch.models.detr.detr import DETR, PostProcess, SetCriterion
from terratorch.models.detr.matcher import HungarianMatcher
from terratorch.models.detr.position_encoding import PositionEmbeddingSine
from terratorch.models.detr.transformer import Transformer

# Deformable DETR requires the CUDA MSDeformAttn extension at import time.
try:
    from terratorch.models.detr.deformable_detr import PostProcess as DeformablePostProcess
    from terratorch.models.detr.deformable_detr import SetCriterion as DeformableSetCriterion

    _adapted_deformable_available = True
except ImportError:
    _adapted_deformable_available = False

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_REF_DETR_PATH = "/private/tmp/detr-ref"
_REF_DEFORMABLE_DETR_PATH = "/private/tmp/deformable-detr-ref"
ATOL = 1e-6

_ref_detr_available = Path(_REF_DETR_PATH).is_dir()
_ref_deformable_detr_available = Path(_REF_DEFORMABLE_DETR_PATH).is_dir()

requires_ref_detr = pytest.mark.skipif(
    not _ref_detr_available,
    reason=f"Reference DETR not found at {_REF_DETR_PATH}",
)
requires_ref_deformable_detr = pytest.mark.skipif(
    not (_ref_deformable_detr_available and _adapted_deformable_available),
    reason="Reference Deformable DETR or CUDA MSDeformAttn extension not available",
)


# ---------------------------------------------------------------------------
# Reference module loading with sys.path isolation
# ---------------------------------------------------------------------------
def _cleanup_ref_modules():
    """Remove reference ``models`` and ``util`` packages from sys.modules."""
    for key in list(sys.modules):
        if key == "models" or key.startswith("models.") or key == "util" or key.startswith("util."):
            del sys.modules[key]


_cached_ref_detr: dict | None = None


def _load_ref_detr():
    """Import reference DETR classes.  Result is cached after first call."""
    global _cached_ref_detr  # noqa: PLW0603
    if _cached_ref_detr is not None:
        return _cached_ref_detr

    _cleanup_ref_modules()
    sys.path.insert(0, _REF_DETR_PATH)
    try:
        from models.detr import DETR as RefDETR  # noqa: N811, PLC0415
        from models.detr import PostProcess as RefPostProcess  # noqa: PLC0415
        from models.detr import SetCriterion as RefSetCriterion  # noqa: PLC0415
        from models.matcher import HungarianMatcher as RefMatcher  # noqa: PLC0415
        from models.position_encoding import PositionEmbeddingSine as RefPosSine  # noqa: PLC0415
        from models.transformer import Transformer as RefTransformer  # noqa: PLC0415
        from util.misc import NestedTensor  # noqa: PLC0415

        _cached_ref_detr = {
            "DETR": RefDETR,
            "SetCriterion": RefSetCriterion,
            "PostProcess": RefPostProcess,
            "Matcher": RefMatcher,
            "Transformer": RefTransformer,
            "PositionEmbeddingSine": RefPosSine,
            "NestedTensor": NestedTensor,
        }
        return _cached_ref_detr
    finally:
        sys.path.remove(_REF_DETR_PATH)
        _cleanup_ref_modules()


_cached_ref_deformable: dict | None = None


def _load_ref_deformable_detr():
    """Import reference Deformable DETR classes.  Result is cached after first call."""
    global _cached_ref_deformable  # noqa: PLW0603
    if _cached_ref_deformable is not None:
        return _cached_ref_deformable

    _cleanup_ref_modules()
    sys.path.insert(0, _REF_DEFORMABLE_DETR_PATH)
    try:
        # Mock the CUDA extension (MSDeformAttn) so deformable_transformer.py
        # can be imported without building the native extension.
        mock_ops = types.ModuleType("models.ops")
        mock_ops_modules = types.ModuleType("models.ops.modules")
        mock_ops_modules.MSDeformAttn = type("MSDeformAttn", (nn.Module,), {})
        mock_ops.modules = mock_ops_modules
        sys.modules["models.ops"] = mock_ops
        sys.modules["models.ops.modules"] = mock_ops_modules

        from models.deformable_detr import PostProcess as RefDefPostProcess  # noqa: PLC0415
        from models.deformable_detr import SetCriterion as RefDefSetCriterion  # noqa: PLC0415

        _cached_ref_deformable = {
            "SetCriterion": RefDefSetCriterion,
            "PostProcess": RefDefPostProcess,
        }
        return _cached_ref_deformable
    except Exception:
        return None
    finally:
        sys.path.remove(_REF_DEFORMABLE_DETR_PATH)
        _cleanup_ref_modules()


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def ref_detr():
    return _load_ref_detr()


@pytest.fixture()
def ref_deformable_detr():
    result = _load_ref_deformable_detr()
    if result is None:
        pytest.skip("Could not import reference Deformable DETR modules")
    return result


# ---------------------------------------------------------------------------
# Mock backbones
# ---------------------------------------------------------------------------
class _SimpleBackbone(nn.Module):
    """Parameterless backbone stub exposing only ``num_channels``."""

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels


class _MockRefBackbone(nn.Module):
    """Mock backbone that returns fixed features for the reference DETR.

    The reference DETR calls ``self.backbone(samples)`` which must return
    ``([NestedTensor(src, mask)], [pos])``.  We ignore the input samples
    and return the pre-stored tensors so that the same ``(src, mask, pos)``
    flows through both adapted and reference models.
    """

    def __init__(self, num_channels, src, mask, pos, nested_tensor_cls):
        super().__init__()
        self.num_channels = num_channels
        self._src = src
        self._mask = mask
        self._pos = pos
        self._nt_cls = nested_tensor_cls

    def forward(self, _samples):
        return [self._nt_cls(self._src, self._mask)], [self._pos]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _dummy_targets(batch_size, num_classes):
    """Create deterministic dummy targets (caller must seed the RNG)."""
    targets = []
    for i in range(batch_size):
        n = i + 1  # 1 object in first image, 2 in second, etc.
        targets.append(
            {
                "labels": torch.randint(0, num_classes, (n,)),
                "boxes": torch.rand(n, 4).clamp(0.1, 0.9),
            }
        )
    return targets


# ===================================================================
# 1. Position Encoding
# ===================================================================
@requires_ref_detr
class TestPositionEncodingParity:
    def test_sine_encoding_matches_reference(self, ref_detr):
        """Stateless sine PE: adapted pe(x) vs reference pe(NestedTensor)."""
        torch.manual_seed(42)
        adapted_pe = PositionEmbeddingSine(64, normalize=True)
        ref_pe = ref_detr["PositionEmbeddingSine"](64, normalize=True)

        x = torch.randn(2, 128, 16, 16)
        mask = torch.zeros(2, 16, 16, dtype=torch.bool)

        adapted_out = adapted_pe(x)  # creates all-False mask internally
        ref_out = ref_pe(ref_detr["NestedTensor"](x, mask))

        assert torch.allclose(adapted_out, ref_out, atol=ATOL), (
            f"Max diff: {(adapted_out - ref_out).abs().max().item()}"
        )


# ===================================================================
# 2. Matcher
# ===================================================================
@requires_ref_detr
class TestMatcherParity:
    def test_hungarian_matcher_identical_indices(self, ref_detr):
        """Same cost coefficients, same inputs → same matching indices."""
        torch.manual_seed(42)
        adapted = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
        ref = ref_detr["Matcher"](cost_class=1, cost_bbox=5, cost_giou=2)

        outputs = {
            "pred_logits": torch.randn(2, 10, 6),
            "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
        }
        targets = [
            {"labels": torch.tensor([0]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]])},
            {
                "labels": torch.tensor([1, 3]),
                "boxes": torch.tensor([[0.3, 0.3, 0.1, 0.1], [0.7, 0.7, 0.3, 0.3]]),
            },
        ]

        adapted_idx = adapted(outputs, targets)
        ref_idx = ref(outputs, targets)

        for (ai, aj), (ri, rj) in zip(adapted_idx, ref_idx, strict=True):
            assert torch.equal(ai, ri), f"Pred idx mismatch: {ai} vs {ri}"
            assert torch.equal(aj, rj), f"Tgt idx mismatch: {aj} vs {rj}"


# ===================================================================
# 3. Transformer
# ===================================================================
@requires_ref_detr
class TestTransformerParity:
    @staticmethod
    def _build_pair(ref_detr, return_intermediate):
        d_model, nhead = 64, 4
        adapted = Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=128,
            dropout=0.0,
            return_intermediate_dec=return_intermediate,
        )
        ref = ref_detr["Transformer"](
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=128,
            dropout=0.0,
            return_intermediate_dec=return_intermediate,
        )
        ref.load_state_dict(adapted.state_dict())
        adapted.eval()
        ref.eval()
        return adapted, ref

    def test_no_intermediate_output_matches(self, ref_detr):
        """return_intermediate_dec=False: hs [1,B,Q,D] and memory match."""
        torch.manual_seed(42)
        adapted, ref = self._build_pair(ref_detr, return_intermediate=False)

        src = torch.randn(2, 64, 8, 8)
        mask = torch.zeros(2, 8, 8, dtype=torch.bool)
        query_embed = torch.randn(10, 64)
        pos_embed = torch.randn(2, 64, 8, 8)

        a_hs, a_mem = adapted(src, mask, query_embed, pos_embed)
        r_hs, r_mem = ref(src, mask, query_embed, pos_embed)

        assert torch.allclose(a_hs, r_hs, atol=ATOL), f"hs diff: {(a_hs - r_hs).abs().max()}"
        assert torch.allclose(a_mem, r_mem, atol=ATOL), f"mem diff: {(a_mem - r_mem).abs().max()}"

    def test_intermediate_output_matches(self, ref_detr):
        """return_intermediate_dec=True: all decoder layer outputs match."""
        torch.manual_seed(42)
        adapted, ref = self._build_pair(ref_detr, return_intermediate=True)

        src = torch.randn(2, 64, 8, 8)
        mask = torch.zeros(2, 8, 8, dtype=torch.bool)
        query_embed = torch.randn(10, 64)
        pos_embed = torch.randn(2, 64, 8, 8)

        a_hs, a_mem = adapted(src, mask, query_embed, pos_embed)
        r_hs, r_mem = ref(src, mask, query_embed, pos_embed)

        assert torch.allclose(a_hs, r_hs, atol=ATOL), f"hs diff: {(a_hs - r_hs).abs().max()}"
        assert torch.allclose(a_mem, r_mem, atol=ATOL), f"mem diff: {(a_mem - r_mem).abs().max()}"


# ===================================================================
# 4. SetCriterion (DETR)
# ===================================================================
@requires_ref_detr
class TestSetCriterionParity:
    @staticmethod
    def _build_pair(ref_detr, num_classes=5):
        # Use the SAME matcher instance so matching indices are identical.
        matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
        weight_dict = {"loss_ce": 1, "loss_bbox": 5, "loss_giou": 2}
        losses = ["labels", "boxes", "cardinality"]

        adapted = SetCriterion(num_classes, matcher, weight_dict, eos_coef=0.1, losses=losses)
        ref = ref_detr["SetCriterion"](num_classes, matcher, weight_dict, eos_coef=0.1, losses=losses)
        # Copy the empty_weight buffer so both have identical class weights.
        ref.load_state_dict(adapted.state_dict())
        return adapted, ref

    def test_loss_values_match(self, ref_detr):
        """All loss scalars match for a basic forward pass."""
        torch.manual_seed(42)
        adapted, ref = self._build_pair(ref_detr)

        outputs = {
            "pred_logits": torch.randn(2, 10, 6),
            "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
        }
        targets = _dummy_targets(2, 5)

        a_losses = adapted(outputs, targets)
        r_losses = ref(outputs, targets)

        for key in a_losses:
            assert key in r_losses, f"Missing key: {key}"
            assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
                f"{key}: {a_losses[key].item():.8f} vs {r_losses[key].item():.8f}"
            )

    def test_loss_with_aux_outputs_match(self, ref_detr):
        """Auxiliary decoder losses (loss_ce_0, loss_bbox_0, ...) also match."""
        torch.manual_seed(42)
        adapted, ref = self._build_pair(ref_detr)

        outputs = {
            "pred_logits": torch.randn(2, 10, 6),
            "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
            "aux_outputs": [
                {
                    "pred_logits": torch.randn(2, 10, 6),
                    "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
                },
                {
                    "pred_logits": torch.randn(2, 10, 6),
                    "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
                },
            ],
        }
        targets = _dummy_targets(2, 5)

        a_losses = adapted(outputs, targets)
        r_losses = ref(outputs, targets)

        for key in a_losses:
            assert key in r_losses, f"Missing key: {key}"
            assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
                f"{key}: {a_losses[key].item():.8f} vs {r_losses[key].item():.8f}"
            )


# ===================================================================
# 5. PostProcess (DETR)
# ===================================================================
@requires_ref_detr
class TestPostProcessParity:
    def test_postprocess_output_matches(self, ref_detr):
        """Stateless post-processor: scores, labels, boxes all match."""
        torch.manual_seed(42)
        adapted = PostProcess()
        ref = ref_detr["PostProcess"]()

        outputs = {
            "pred_logits": torch.randn(2, 10, 6),
            "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
        }
        target_sizes = torch.tensor([[480, 640], [320, 480]])

        a_results = adapted(outputs, target_sizes)
        r_results = ref(outputs, target_sizes)

        for a, r in zip(a_results, r_results, strict=True):
            assert torch.allclose(a["scores"], r["scores"], atol=ATOL)
            assert torch.equal(a["labels"], r["labels"])
            assert torch.allclose(a["boxes"], r["boxes"], atol=ATOL)


# ===================================================================
# 6. DETR Model (inner class, bypassing backbone)
# ===================================================================
@requires_ref_detr
class TestDETRModelParity:
    @staticmethod
    def _build_pair(ref_detr, *, aux_loss, src, mask, pos):
        """Build adapted and reference DETR with shared weights."""
        d_model, nhead = 64, 4
        num_classes, num_queries = 5, 10
        num_channels = src.shape[1]
        nested_tensor_cls = ref_detr["NestedTensor"]

        adapted_model = DETR(
            _SimpleBackbone(num_channels),
            Transformer(
                d_model=d_model,
                nhead=nhead,
                num_encoder_layers=2,
                num_decoder_layers=2,
                dim_feedforward=128,
                dropout=0.0,
            ),
            num_classes=num_classes,
            num_queries=num_queries,
            aux_loss=aux_loss,
        )

        ref_model = ref_detr["DETR"](
            _MockRefBackbone(num_channels, src, mask, pos, nested_tensor_cls),
            ref_detr["Transformer"](
                d_model=d_model,
                nhead=nhead,
                num_encoder_layers=2,
                num_decoder_layers=2,
                dim_feedforward=128,
                dropout=0.0,
            ),
            num_classes=num_classes,
            num_queries=num_queries,
            aux_loss=aux_loss,
        )

        # All non-backbone param names are identical → strict load succeeds.
        ref_model.load_state_dict(adapted_model.state_dict())
        adapted_model.eval()
        ref_model.eval()
        return adapted_model, ref_model, nested_tensor_cls

    def test_detr_forward_no_aux_loss(self, ref_detr):
        """pred_logits and pred_boxes match without auxiliary loss."""
        torch.manual_seed(42)
        src = torch.randn(2, 32, 8, 8)
        mask = torch.zeros(2, 8, 8, dtype=torch.bool)
        pos = torch.randn(2, 64, 8, 8)

        adapted, ref, nt_cls = self._build_pair(
            ref_detr,
            aux_loss=False,
            src=src,
            mask=mask,
            pos=pos,
        )

        a_out = adapted(src, mask, pos)
        # The NestedTensor content doesn't matter — mock backbone ignores it.
        r_out = ref(nt_cls(torch.randn(2, 3, 8, 8), mask))

        assert torch.allclose(a_out["pred_logits"], r_out["pred_logits"], atol=ATOL), (
            f"logits diff: {(a_out['pred_logits'] - r_out['pred_logits']).abs().max()}"
        )
        assert torch.allclose(a_out["pred_boxes"], r_out["pred_boxes"], atol=ATOL), (
            f"boxes diff: {(a_out['pred_boxes'] - r_out['pred_boxes']).abs().max()}"
        )
        assert "aux_outputs" not in a_out
        assert "aux_outputs" not in r_out

    def test_detr_forward_with_aux_loss(self, ref_detr):
        """pred_logits, pred_boxes, and all aux_outputs match."""
        torch.manual_seed(42)
        src = torch.randn(2, 32, 8, 8)
        mask = torch.zeros(2, 8, 8, dtype=torch.bool)
        pos = torch.randn(2, 64, 8, 8)

        adapted, ref, nt_cls = self._build_pair(
            ref_detr,
            aux_loss=True,
            src=src,
            mask=mask,
            pos=pos,
        )

        a_out = adapted(src, mask, pos)
        r_out = ref(nt_cls(torch.randn(2, 3, 8, 8), mask))

        assert torch.allclose(a_out["pred_logits"], r_out["pred_logits"], atol=ATOL)
        assert torch.allclose(a_out["pred_boxes"], r_out["pred_boxes"], atol=ATOL)

        assert len(a_out["aux_outputs"]) == len(r_out["aux_outputs"])
        for i, (a, r) in enumerate(zip(a_out["aux_outputs"], r_out["aux_outputs"], strict=True)):
            assert torch.allclose(a["pred_logits"], r["pred_logits"], atol=ATOL), (
                f"aux[{i}] logits diff: {(a['pred_logits'] - r['pred_logits']).abs().max()}"
            )
            assert torch.allclose(a["pred_boxes"], r["pred_boxes"], atol=ATOL), (
                f"aux[{i}] boxes diff: {(a['pred_boxes'] - r['pred_boxes']).abs().max()}"
            )


# ===================================================================
# 7. Deformable DETR — SetCriterion and PostProcess
# ===================================================================
@requires_ref_deformable_detr
class TestDeformableSetCriterionParity:
    def test_loss_values_match(self, ref_deformable_detr):
        """Focal-loss-based SetCriterion produces identical loss values."""
        torch.manual_seed(42)
        num_classes = 5
        # Use the SAME matcher for both so matching indices are identical.
        matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
        weight_dict = {"loss_ce": 2, "loss_bbox": 5, "loss_giou": 2}
        losses = ["labels", "boxes", "cardinality"]

        adapted = DeformableSetCriterion(
            num_classes,
            matcher,
            weight_dict,
            losses=losses,
            focal_alpha=0.25,
        )
        ref = ref_deformable_detr["SetCriterion"](
            num_classes,
            matcher,
            weight_dict,
            losses=losses,
            focal_alpha=0.25,
        )

        outputs = {
            "pred_logits": torch.randn(2, 10, num_classes),
            "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
        }
        targets = _dummy_targets(2, num_classes)

        a_losses = adapted(outputs, targets)
        r_losses = ref(outputs, targets)

        for key in a_losses:
            assert key in r_losses, f"Missing key: {key}"
            assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
                f"{key}: {a_losses[key].item():.8f} vs {r_losses[key].item():.8f}"
            )


@requires_ref_deformable_detr
class TestDeformablePostProcessParity:
    def test_postprocess_output_matches(self, ref_deformable_detr):
        """Top-k sigmoid post-processor: scores, labels, boxes all match."""
        torch.manual_seed(42)
        adapted = DeformablePostProcess()
        ref = ref_deformable_detr["PostProcess"]()

        outputs = {
            "pred_logits": torch.randn(2, 10, 5),
            "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
        }
        target_sizes = torch.tensor([[480, 640], [320, 480]])

        a_results = adapted(outputs, target_sizes)
        r_results = ref(outputs, target_sizes)

        for a, r in zip(a_results, r_results, strict=True):
            assert torch.allclose(a["scores"], r["scores"], atol=ATOL)
            assert torch.equal(a["labels"], r["labels"])
            assert torch.allclose(a["boxes"], r["boxes"], atol=ATOL)


# ===================================================================
# 8. RF-DETR Parity Tests
# ===================================================================
_REF_RFDETR_PATH = "/private/tmp/rf-detr-ref/src"
_ref_rfdetr_available = Path(_REF_RFDETR_PATH).is_dir()

requires_ref_rfdetr = pytest.mark.skipif(
    not _ref_rfdetr_available,
    reason=f"Reference RF-DETR not found at {_REF_RFDETR_PATH}",
)

_cached_ref_rfdetr: dict | None = None

# RF-DETR reference modules to clean up after loading
_RFDETR_MODULE_PREFIXES = ("rfdetr",)


def _cleanup_rfdetr_modules():
    """Remove reference rfdetr modules from sys.modules."""
    for key in list(sys.modules):
        if key == "rfdetr" or key.startswith("rfdetr."):
            del sys.modules[key]


def _load_ref_rfdetr():
    """Import reference RF-DETR classes.  Result is cached after first call.

    RF-DETR's reference code has heavy transitive dependencies (transformers,
    peft, supervision, DINOv2 backbone). We mock these and manually set up
    the package hierarchy to only load the model/loss/matcher files we need.
    """
    import importlib.util  # noqa: PLC0415
    import logging  # noqa: PLC0415
    from unittest.mock import MagicMock  # noqa: PLC0415

    global _cached_ref_rfdetr  # noqa: PLW0603
    if _cached_ref_rfdetr is not None:
        return _cached_ref_rfdetr

    _cleanup_rfdetr_modules()
    ref_path = _REF_RFDETR_PATH
    sys.path.insert(0, ref_path)

    try:
        # Pre-create rfdetr package hierarchy (prevents __init__.py from running)
        for pkg, sub_path in [
            ("rfdetr", "rfdetr"),
            ("rfdetr.util", "rfdetr/util"),
            ("rfdetr.models", "rfdetr/models"),
            ("rfdetr.models.ops", "rfdetr/models/ops"),
            ("rfdetr.models.ops.functions", "rfdetr/models/ops/functions"),
            ("rfdetr.models.ops.modules", "rfdetr/models/ops/modules"),
        ]:
            mod = types.ModuleType(pkg)
            mod.__path__ = [ref_path + "/" + sub_path]
            mod.__package__ = pkg
            sys.modules[pkg] = mod

        # Mock rfdetr.util.logger with a real get_logger
        mock_logger = types.ModuleType("rfdetr.util.logger")
        mock_logger.get_logger = lambda name="rfdetr": logging.getLogger(name)
        sys.modules["rfdetr.util.logger"] = mock_logger

        # Mock backbone entirely (requires transformers/peft/DINOv2)
        sys.modules["rfdetr.models.backbone"] = MagicMock(__path__=[])

        # Load ops functions
        for fpath, mod_name, parent_name, attr_name in [
            (
                "rfdetr/models/ops/functions/ms_deform_attn_func.py",
                "rfdetr.models.ops.functions.ms_deform_attn_func",
                "rfdetr.models.ops.functions",
                "ms_deform_attn_core_pytorch",
            ),
            (
                "rfdetr/models/ops/modules/ms_deform_attn.py",
                "rfdetr.models.ops.modules.ms_deform_attn",
                "rfdetr.models.ops.modules",
                "MSDeformAttn",
            ),
        ]:
            spec = importlib.util.spec_from_file_location(mod_name, ref_path + "/" + fpath)
            mod = importlib.util.module_from_spec(spec)
            sys.modules[mod_name] = mod
            spec.loader.exec_module(mod)
            setattr(sys.modules[parent_name], attr_name, getattr(mod, attr_name))

        # Load util modules
        for fname, mod_name in [
            ("rfdetr/util/box_ops.py", "rfdetr.util.box_ops"),
            ("rfdetr/util/misc.py", "rfdetr.util.misc"),
        ]:
            spec = importlib.util.spec_from_file_location(mod_name, ref_path + "/" + fname)
            mod = importlib.util.module_from_spec(spec)
            sys.modules[mod_name] = mod
            spec.loader.exec_module(mod)

        # Now import the specific classes
        from rfdetr.models.lwdetr import LWDETR as RefRFLWDETR  # noqa: PLC0415, N811
        from rfdetr.models.lwdetr import PostProcess as RefRFPostProcess  # noqa: PLC0415
        from rfdetr.models.lwdetr import SetCriterion as RefRFSetCriterion  # noqa: PLC0415
        from rfdetr.models.matcher import HungarianMatcher as RefRFMatcher  # noqa: PLC0415
        from rfdetr.models.position_encoding import PositionEmbeddingSine as RefRFPosSine  # noqa: PLC0415
        from rfdetr.models.segmentation_head import SegmentationHead as RefRFSegHead  # noqa: PLC0415
        from rfdetr.models.transformer import Transformer as RefRFTransformer  # noqa: PLC0415
        from rfdetr.models.transformer import gen_encoder_output_proposals as ref_gen_proposals  # noqa: PLC0415
        from rfdetr.util.misc import NestedTensor as RefRFNestedTensor  # noqa: PLC0415

        _cached_ref_rfdetr = {
            "SetCriterion": RefRFSetCriterion,
            "PostProcess": RefRFPostProcess,
            "Matcher": RefRFMatcher,
            "PositionEmbeddingSine": RefRFPosSine,
            "SegmentationHead": RefRFSegHead,
            "NestedTensor": RefRFNestedTensor,
            "Transformer": RefRFTransformer,
            "LWDETR": RefRFLWDETR,
            "gen_encoder_output_proposals": ref_gen_proposals,
        }
        return _cached_ref_rfdetr
    except Exception:
        return None
    finally:
        if ref_path in sys.path:
            sys.path.remove(ref_path)
        _cleanup_rfdetr_modules()


@pytest.fixture()
def ref_rfdetr():
    result = _load_ref_rfdetr()
    if result is None:
        pytest.skip("Could not import reference RF-DETR modules")
    return result


# -------------------------------------------------------------------
# 8a. RF-DETR Position Encoding Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRPositionEncodingParity:
    def test_sine_encoding_matches_reference(self, ref_rfdetr):
        """adapted pe(x) vs reference pe(NestedTensor)."""
        from terratorch.models.detr.rfdetr.position_encoding import PositionEmbeddingSine  # noqa: PLC0415

        torch.manual_seed(42)
        adapted_pe = PositionEmbeddingSine(64, normalize=True)
        ref_pe = ref_rfdetr["PositionEmbeddingSine"](64, normalize=True)

        x = torch.randn(2, 128, 16, 16)
        mask = torch.zeros(2, 16, 16, dtype=torch.bool)

        adapted_out = adapted_pe(x)
        ref_out = ref_pe(ref_rfdetr["NestedTensor"](x, mask), align_dim_orders=False)

        assert torch.allclose(adapted_out, ref_out, atol=ATOL), (
            f"Max diff: {(adapted_out - ref_out).abs().max().item()}"
        )


# -------------------------------------------------------------------
# 8b. RF-DETR Matcher Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRMatcherParity:
    def test_hungarian_matcher_identical_indices(self, ref_rfdetr):
        """Same cost coefficients, same inputs -> same matching indices."""
        from terratorch.models.detr.rfdetr.matcher import HungarianMatcher  # noqa: PLC0415

        torch.manual_seed(42)
        adapted = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
        ref = ref_rfdetr["Matcher"](cost_class=1, cost_bbox=5, cost_giou=2)

        outputs = {
            "pred_logits": torch.randn(2, 10, 6),
            "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
        }
        targets = [
            {"labels": torch.tensor([0]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]])},
            {
                "labels": torch.tensor([1, 3]),
                "boxes": torch.tensor([[0.3, 0.3, 0.1, 0.1], [0.7, 0.7, 0.3, 0.3]]),
            },
        ]

        adapted_idx = adapted(outputs, targets)
        ref_idx = ref(outputs, targets)

        for (ai, aj), (ri, rj) in zip(adapted_idx, ref_idx, strict=True):
            assert torch.equal(ai, ri), f"Pred idx mismatch: {ai} vs {ri}"
            assert torch.equal(aj, rj), f"Tgt idx mismatch: {aj} vs {rj}"


# -------------------------------------------------------------------
# 8c. RF-DETR SetCriterion Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRSetCriterionParity:
    def test_loss_values_match(self, ref_rfdetr):
        """Focal-loss-based SetCriterion produces identical loss values."""
        from terratorch.models.detr.rfdetr.lwdetr import SetCriterion  # noqa: PLC0415
        from terratorch.models.detr.rfdetr.matcher import HungarianMatcher  # noqa: PLC0415

        torch.manual_seed(42)
        num_classes = 5
        matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
        weight_dict = {"loss_ce": 2, "loss_bbox": 5, "loss_giou": 2}
        losses = ["labels", "boxes", "cardinality"]

        adapted = SetCriterion(
            num_classes,
            matcher,
            weight_dict,
            focal_alpha=0.25,
            losses=losses,
        )
        ref = ref_rfdetr["SetCriterion"](
            num_classes,
            matcher,
            weight_dict,
            focal_alpha=0.25,
            losses=losses,
        )

        outputs = {
            "pred_logits": torch.randn(2, 10, num_classes),
            "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
        }
        targets = _dummy_targets(2, num_classes)

        a_losses = adapted(outputs, targets)
        r_losses = ref(outputs, targets)

        for key in a_losses:
            assert key in r_losses, f"Missing key: {key}"
            assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
                f"{key}: {a_losses[key].item():.8f} vs {r_losses[key].item():.8f}"
            )


# -------------------------------------------------------------------
# 8d. RF-DETR PostProcess Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRPostProcessParity:
    def test_postprocess_output_matches(self, ref_rfdetr):
        """Top-k sigmoid post-processor: scores, labels, boxes all match."""
        from terratorch.models.detr.rfdetr.lwdetr import PostProcess  # noqa: PLC0415

        torch.manual_seed(42)
        adapted = PostProcess(num_select=10)
        ref = ref_rfdetr["PostProcess"](num_select=10)

        outputs = {
            "pred_logits": torch.randn(2, 20, 5),
            "pred_boxes": torch.rand(2, 20, 4).clamp(0.1, 0.9),
        }
        target_sizes = torch.tensor([[480, 640], [320, 480]])

        a_results = adapted(outputs, target_sizes)
        r_results = ref(outputs, target_sizes)

        for a, r in zip(a_results, r_results, strict=True):
            assert torch.allclose(a["scores"], r["scores"], atol=ATOL)
            assert torch.equal(a["labels"], r["labels"])
            assert torch.allclose(a["boxes"], r["boxes"], atol=ATOL)


# -------------------------------------------------------------------
# 8e. RF-DETR Segmentation Head Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRSegmentationHeadParity:
    def test_segmentation_head_forward_matches(self, ref_rfdetr):
        """Segmentation head forward with identical weights -> identical output."""
        from terratorch.models.detr.rfdetr.segmentation_head import SegmentationHead  # noqa: PLC0415

        torch.manual_seed(42)
        adapted = SegmentationHead(in_dim=64, num_blocks=2)
        ref = ref_rfdetr["SegmentationHead"](in_dim=64, num_blocks=2)
        ref.load_state_dict(adapted.state_dict())
        adapted.eval()
        ref.eval()

        spatial_features = torch.randn(2, 64, 16, 16)
        query_features = [torch.randn(2, 5, 64), torch.randn(2, 5, 64)]
        image_size = (64, 64)

        a_out = adapted(spatial_features, query_features, image_size)
        r_out = ref(spatial_features, query_features, image_size)

        assert len(a_out) == len(r_out)
        for a, r in zip(a_out, r_out, strict=True):
            assert torch.allclose(a, r, atol=ATOL), f"Max diff: {(a - r).abs().max().item()}"


# -------------------------------------------------------------------
# 8f. RF-DETR gen_encoder_output_proposals Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRGenProposalsParity:
    """Verify gen_encoder_output_proposals matches reference for both unsigmoid modes."""

    def test_proposals_with_mask(self, ref_rfdetr):
        """Proposals with padding mask: both unsigmoid=True and False."""
        from terratorch.models.detr.rfdetr.transformer import gen_encoder_output_proposals  # noqa: PLC0415

        torch.manual_seed(42)
        bs, d_model = 2, 64
        spatial_shapes = torch.tensor([[8, 8], [4, 4]], dtype=torch.long)
        total_len = 8 * 8 + 4 * 4  # 80
        memory = torch.randn(bs, total_len, d_model)
        mask = torch.zeros(bs, total_len, dtype=torch.bool)
        mask[:, 60:] = True  # mask some trailing positions

        for unsigmoid in [True, False]:
            a_mem, a_prop = gen_encoder_output_proposals(memory, mask, spatial_shapes, unsigmoid=unsigmoid)
            r_mem, r_prop = ref_rfdetr["gen_encoder_output_proposals"](
                memory, mask, spatial_shapes, unsigmoid=unsigmoid
            )
            assert torch.allclose(a_mem, r_mem, atol=ATOL), (
                f"Memory mismatch (unsigmoid={unsigmoid}): {(a_mem - r_mem).abs().max().item()}"
            )
            assert torch.allclose(a_prop, r_prop, atol=ATOL), (
                f"Proposals mismatch (unsigmoid={unsigmoid}): {(a_prop - r_prop).abs().max().item()}"
            )

    def test_proposals_no_mask(self, ref_rfdetr):
        """Proposals with mask=None (no padding)."""
        from terratorch.models.detr.rfdetr.transformer import gen_encoder_output_proposals  # noqa: PLC0415

        torch.manual_seed(42)
        bs, d_model = 2, 64
        spatial_shapes = torch.tensor([[8, 8], [4, 4]], dtype=torch.long)
        total_len = 8 * 8 + 4 * 4
        memory = torch.randn(bs, total_len, d_model)

        a_mem, a_prop = gen_encoder_output_proposals(memory, None, spatial_shapes)
        r_mem, r_prop = ref_rfdetr["gen_encoder_output_proposals"](memory, None, spatial_shapes)

        assert torch.allclose(a_mem, r_mem, atol=ATOL)
        assert torch.allclose(a_prop, r_prop, atol=ATOL)


# -------------------------------------------------------------------
# 8g. RF-DETR LWDETR End-to-End Forward Parity
# -------------------------------------------------------------------
class _MockRFDETRBackbone(nn.Module):
    """Mock backbone that returns pre-set features for reference LWDETR."""

    def __init__(self, srcs, masks, poss, nested_tensor_cls):
        super().__init__()
        self._srcs = srcs
        self._masks = masks
        self._poss = poss
        self._nt_cls = nested_tensor_cls

    def forward(self, _samples):
        features = [self._nt_cls(s, m) for s, m in zip(self._srcs, self._masks, strict=True)]
        return features, self._poss


def _compare_outputs(a_out, r_out, atol=ATOL):
    """Recursively compare LWDETR output dicts."""
    for key in a_out:
        assert key in r_out, f"Missing key in reference output: {key}"
        if key == "aux_outputs":
            assert len(a_out[key]) == len(r_out[key])
            for i, (a_aux, r_aux) in enumerate(zip(a_out[key], r_out[key], strict=True)):
                for ak in a_aux:
                    assert torch.allclose(a_aux[ak], r_aux[ak], atol=atol), (
                        f"aux_outputs[{i}][{ak}]: max diff {(a_aux[ak] - r_aux[ak]).abs().max().item()}"
                    )
        elif key == "enc_outputs":
            for ek in a_out[key]:
                assert torch.allclose(a_out[key][ek], r_out[key][ek], atol=atol), (
                    f"enc_outputs[{ek}]: max diff {(a_out[key][ek] - r_out[key][ek]).abs().max().item()}"
                )
        else:
            assert torch.allclose(a_out[key], r_out[key], atol=atol), (
                f"{key}: max diff {(a_out[key] - r_out[key]).abs().max().item()}"
            )


@requires_ref_rfdetr
class TestRFDETRLWDETRForwardParity:
    """Verify LWDETR forward produces identical output to reference (covers Transformer,
    gen_encoder_output_proposals, bbox_reparam, two-stage, iterative refinement)."""

    @staticmethod
    def _make_models(ref_rfdetr, *, two_stage, bbox_reparam, aux_loss):
        """Build adapted + reference LWDETR with matching weights and dummy features."""
        from terratorch.models.detr.rfdetr.lwdetr import LWDETR  # noqa: PLC0415
        from terratorch.models.detr.rfdetr.transformer import Transformer  # noqa: PLC0415

        d_model, num_queries, nhead, num_layers = 64, 10, 4, 2
        num_feature_levels, num_classes = 2, 5

        adapted_transformer = Transformer(
            d_model=d_model,
            sa_nhead=nhead,
            ca_nhead=nhead,
            num_queries=num_queries,
            num_decoder_layers=num_layers,
            dim_feedforward=128,
            dropout=0.0,
            return_intermediate_dec=True,
            two_stage=two_stage,
            num_feature_levels=num_feature_levels,
            dec_n_points=4,
            bbox_reparam=bbox_reparam,
        )
        adapted = LWDETR(
            adapted_transformer,
            segmentation_head=None,
            num_classes=num_classes,
            num_queries=num_queries,
            aux_loss=aux_loss,
            two_stage=two_stage,
            bbox_reparam=bbox_reparam,
        )

        # Dummy features
        torch.manual_seed(42)
        bs = 2
        srcs = [torch.randn(bs, d_model, 8, 8), torch.randn(bs, d_model, 4, 4)]
        masks = [torch.zeros(bs, 8, 8, dtype=torch.bool), torch.zeros(bs, 4, 4, dtype=torch.bool)]
        poss = [torch.randn(bs, d_model, 8, 8), torch.randn(bs, d_model, 4, 4)]

        mock_backbone = _MockRFDETRBackbone(srcs, masks, poss, ref_rfdetr["NestedTensor"])

        ref_transformer = ref_rfdetr["Transformer"](
            d_model=d_model,
            sa_nhead=nhead,
            ca_nhead=nhead,
            num_queries=num_queries,
            num_decoder_layers=num_layers,
            dim_feedforward=128,
            dropout=0.0,
            return_intermediate_dec=True,
            two_stage=two_stage,
            num_feature_levels=num_feature_levels,
            dec_n_points=4,
            bbox_reparam=bbox_reparam,
        )
        ref_model = ref_rfdetr["LWDETR"](
            mock_backbone,
            ref_transformer,
            segmentation_head=None,
            num_classes=num_classes,
            num_queries=num_queries,
            aux_loss=aux_loss,
            two_stage=two_stage,
            bbox_reparam=bbox_reparam,
        )

        # Copy all weights (mock backbone has no params, so state_dicts should match)
        ref_model.load_state_dict(adapted.state_dict(), strict=True)
        adapted.eval()
        ref_model.eval()

        return adapted, ref_model, srcs, masks, poss

    def test_forward_no_two_stage(self, ref_rfdetr):
        """LWDETR forward without two-stage: outputs match reference exactly."""
        adapted, ref_model, srcs, masks, poss = self._make_models(
            ref_rfdetr, two_stage=False, bbox_reparam=False, aux_loss=True
        )
        with torch.no_grad():
            a_out = adapted(srcs, masks, poss)
            # Reference forward takes NestedTensor; mock backbone ignores input
            dummy_nt = ref_rfdetr["NestedTensor"](torch.randn(2, 3, 64, 64), torch.zeros(2, 64, 64, dtype=torch.bool))
            r_out = ref_model(dummy_nt)

        _compare_outputs(a_out, r_out)

    def test_forward_two_stage_bbox_reparam(self, ref_rfdetr):
        """LWDETR forward with two_stage + bbox_reparam: outputs match reference exactly."""
        adapted, ref_model, srcs, masks, poss = self._make_models(
            ref_rfdetr, two_stage=True, bbox_reparam=True, aux_loss=True
        )
        with torch.no_grad():
            a_out = adapted(srcs, masks, poss)
            dummy_nt = ref_rfdetr["NestedTensor"](torch.randn(2, 3, 64, 64), torch.zeros(2, 64, 64, dtype=torch.bool))
            r_out = ref_model(dummy_nt)

        _compare_outputs(a_out, r_out)
        # Verify two-stage keys are present
        assert "enc_outputs" in a_out
        assert "pred_logits" in a_out["enc_outputs"]
        assert "pred_boxes" in a_out["enc_outputs"]

    def test_forward_two_stage_no_bbox_reparam(self, ref_rfdetr):
        """LWDETR forward with two_stage but no bbox_reparam."""
        adapted, ref_model, srcs, masks, poss = self._make_models(
            ref_rfdetr, two_stage=True, bbox_reparam=False, aux_loss=True
        )
        with torch.no_grad():
            a_out = adapted(srcs, masks, poss)
            dummy_nt = ref_rfdetr["NestedTensor"](torch.randn(2, 3, 64, 64), torch.zeros(2, 64, 64, dtype=torch.bool))
            r_out = ref_model(dummy_nt)

        _compare_outputs(a_out, r_out)


# -------------------------------------------------------------------
# 8h. RF-DETR SetCriterion ia_bce_loss + aux/enc Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRSetCriterionIaBceParity:
    """Verify SetCriterion with ia_bce_loss produces identical losses to reference."""

    def test_ia_bce_loss_values_match(self, ref_rfdetr):
        """ia_bce_loss-based SetCriterion: all loss values match reference."""
        from terratorch.models.detr.rfdetr.lwdetr import SetCriterion  # noqa: PLC0415
        from terratorch.models.detr.rfdetr.matcher import HungarianMatcher  # noqa: PLC0415

        torch.manual_seed(42)
        num_classes = 5
        matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
        weight_dict = {"loss_ce": 2, "loss_bbox": 5, "loss_giou": 2}
        losses = ["labels", "boxes", "cardinality"]

        adapted = SetCriterion(num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, ia_bce_loss=True)
        ref = ref_rfdetr["SetCriterion"](
            num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, ia_bce_loss=True
        )
        adapted.eval()
        ref.eval()

        outputs = {
            "pred_logits": torch.randn(2, 10, num_classes),
            "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
        }
        targets = _dummy_targets(2, num_classes)

        a_losses = adapted(outputs, targets)
        r_losses = ref(outputs, targets)

        for key in a_losses:
            assert key in r_losses, f"Missing key in ref: {key}"
            assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
                f"{key}: adapted={a_losses[key].item():.8f} vs ref={r_losses[key].item():.8f}"
            )

    def test_with_aux_and_enc_outputs(self, ref_rfdetr):
        """SetCriterion with ia_bce_loss + aux_outputs + enc_outputs: all losses match."""
        from terratorch.models.detr.rfdetr.lwdetr import SetCriterion  # noqa: PLC0415
        from terratorch.models.detr.rfdetr.matcher import HungarianMatcher  # noqa: PLC0415

        torch.manual_seed(42)
        num_classes = 5
        matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
        weight_dict = {
            "loss_ce": 2,
            "loss_bbox": 5,
            "loss_giou": 2,
            "loss_ce_0": 2,
            "loss_bbox_0": 5,
            "loss_giou_0": 2,
            "loss_ce_enc": 2,
            "loss_bbox_enc": 5,
            "loss_giou_enc": 2,
        }
        losses = ["labels", "boxes", "cardinality"]

        adapted = SetCriterion(num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, ia_bce_loss=True)
        ref = ref_rfdetr["SetCriterion"](
            num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, ia_bce_loss=True
        )
        adapted.eval()
        ref.eval()

        outputs = {
            "pred_logits": torch.randn(2, 10, num_classes),
            "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
            "aux_outputs": [
                {
                    "pred_logits": torch.randn(2, 10, num_classes),
                    "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
                },
            ],
            "enc_outputs": {
                "pred_logits": torch.randn(2, 10, num_classes),
                "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
            },
        }
        targets = _dummy_targets(2, num_classes)

        a_losses = adapted(outputs, targets)
        r_losses = ref(outputs, targets)

        for key in a_losses:
            assert key in r_losses, f"Missing key in ref: {key}"
            assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
                f"{key}: adapted={a_losses[key].item():.8f} vs ref={r_losses[key].item():.8f}"
            )

        # Verify auxiliary and encoder loss keys are present
        assert "loss_ce_0" in a_losses
        assert "loss_bbox_0" in a_losses
        assert "loss_ce_enc" in a_losses
        assert "loss_bbox_enc" in a_losses

    def test_varifocal_loss_values_match(self, ref_rfdetr):
        """Varifocal loss: all loss values match reference."""
        from terratorch.models.detr.rfdetr.lwdetr import SetCriterion  # noqa: PLC0415
        from terratorch.models.detr.rfdetr.matcher import HungarianMatcher  # noqa: PLC0415

        torch.manual_seed(42)
        num_classes = 5
        matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
        weight_dict = {"loss_ce": 2, "loss_bbox": 5, "loss_giou": 2}
        losses = ["labels", "boxes", "cardinality"]

        adapted = SetCriterion(
            num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, use_varifocal_loss=True
        )
        ref = ref_rfdetr["SetCriterion"](
            num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, use_varifocal_loss=True
        )
        adapted.eval()
        ref.eval()

        outputs = {
            "pred_logits": torch.randn(2, 10, num_classes),
            "pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
        }
        targets = _dummy_targets(2, num_classes)

        a_losses = adapted(outputs, targets)
        r_losses = ref(outputs, targets)

        for key in a_losses:
            assert key in r_losses, f"Missing key in ref: {key}"
            assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
                f"{key}: adapted={a_losses[key].item():.8f} vs ref={r_losses[key].item():.8f}"
            )
