Add DETR, Deformable DETR, and RF-DETR Object Detection / Instance Segmentation Network Architectures#1113
Open
josiahwsmith10 wants to merge 10 commits into terrastackai:main
Conversation
Copy verbatim reference implementations for provenance tracking:
- detr.py, transformer.py, matcher.py, position_encoding.py from facebook/detr
- deformable_detr.py, deformable_transformer.py from fundamentalvision/Deformable-DETR
Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
Modify the verbatim reference files from facebook/detr and fundamentalvision/Deformable-DETR for use with TerraTorch's ObjectDetectionModelFactory:
- Remove external dependencies (NestedTensor, util.box_ops, util.misc, segmentation, backbone modules) and replace them with torchvision.ops equivalents
- Add terratorch_detr.py with TerraTorchDETR and TerraTorchDeformableDETR wrapper classes that compose reference components with BackboneWrapper
- Update __init__.py to re-export the wrapper classes as DETR and DeformableDETR
- Integrate with ObjectDetectionModelFactory ('detr' and 'deformable-detr' framework options; see the sketch below)
- Add tests and an example YAML config
- Ruff lint and format compliance
Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
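For orientation, a minimal usage sketch of the new factory integration. The import path and build_model signature here are assumptions inferred from the commit description, not the verified API:

# Hedged sketch -- argument names are assumptions from the commit text;
# check ObjectDetectionModelFactory for the actual signature.
from terratorch.models import ObjectDetectionModelFactory

factory = ObjectDetectionModelFactory()
model = factory.build_model(
    task="object_detection",
    framework="detr",          # or "deformable-detr"
    backbone="timm_resnet18",  # any TerraTorch backbone, wrapped by BackboneWrapper
    num_classes=5,
)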
… DETR and Deformable DETR
Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
Copy verbatim model files from https://github.com/roboflow/rf-detr (commit 82ad5d1) under the Apache-2.0 License for adaptation into TerraTorch's ObjectDetectionModelFactory.
Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
Verbatim copy of facebook/detr models/segmentation.py containing the mask prediction heads (MHAttentionMap, MaskHeadSmallConv), losses (dice_loss, sigmoid_focal_loss), and post-processing (PostProcessSegm, PostProcessPanoptic). Will be adapted for TerraTorch in a subsequent commit.
Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
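For context, the dice loss carried over here has roughly the following shape. This is paraphrased from the facebook/detr reference; consult the vendored segmentation.py for the authoritative version:

import torch

def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, num_boxes: float) -> torch.Tensor:
    # inputs: mask logits [N, H, W]; targets: binary masks already flattened to [N, H*W]
    inputs = inputs.sigmoid().flatten(1)            # logits -> probabilities, [N, H*W]
    numerator = 2 * (inputs * targets).sum(1)       # soft intersection per mask
    denominator = inputs.sum(-1) + targets.sum(-1)  # soft "union"
    loss = 1 - (numerator + 1) / (denominator + 1)  # +1 smoothing avoids 0/0
    return loss.sum() / num_boxes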
Adapt reference segmentation.py for TerraTorch:
- Remove the DETRsegm wrapper, PostProcessPanoptic, and NestedTensor; keep MHAttentionMap, MaskHeadSmallConv, dice_loss, sigmoid_focal_loss, and PostProcessSegm
- Add loss_masks to both SetCriterion classes, with masks skipped in the aux loss loop
- Integrate the mask head into TerraTorchDETR and TerraTorchDeformableDETR via a masks=True parameter (sketched below)
- Update ObjectDetectionTask for DETR segmentation metrics/masks
- Add component, criterion, and end-to-end segmentation tests
Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
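A hedged sketch of enabling the mask head, reusing the assumed factory call from the earlier sketch. Only masks=True is named in the commit; the surrounding argument names remain assumptions:

# Same hypothetical factory API as above; only masks=True is confirmed by the commit text.
seg_model = factory.build_model(
    task="object_detection",
    framework="detr",
    backbone="timm_resnet18",
    num_classes=5,
    masks=True,  # attaches MHAttentionMap + MaskHeadSmallConv for instance masks
)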
Integrate RF-DETR (Roboflow Detection Transformer) into TerraTorch's ObjectDetectionModelFactory alongside the existing DETR and Deformable DETR:
- Add a TerraTorchRFDETR wrapper with the decoder-only LW-DETR transformer, two-stage encoder proposals, bbox reparameterization, and pure-PyTorch multi-scale deformable attention (no CUDA extension required; sketched below)
- Register the 'rf-detr' framework in ObjectDetectionModelFactory
- Remove the vendored DINOv2 backbone (replaced by TerraTorch's BackboneWrapper)
- Adapt all RF-DETR components for ruff compliance (F→f_nn, assert→raise, Optional→union syntax, naming conventions, etc.)
- Add comprehensive parity tests proving numerical identity with the reference: gen_encoder_output_proposals; full LWDETR forward (no two-stage, two-stage + bbox_reparam, two-stage without bbox_reparam); SetCriterion with ia_bce_loss and varifocal_loss, including aux/enc outputs; position encoding; matcher; PostProcess; and SegmentationHead
- Add functional tests for factory build, train/eval forward, aux/enc loss, and individual component validation
Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
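The pure-PyTorch deformable attention mentioned above is the grid_sample-based core from the Deformable-DETR/RF-DETR repositories (ms_deform_attn_core_pytorch). Below is a sketch reconstructed from the reference, with shapes noted in comments; the vendored copy in this PR may differ in naming:

import torch
import torch.nn.functional as F

def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    # value:              [N, S, M, D]       flattened multi-scale features (S = sum of H*W over levels)
    # sampling_locations: [N, Lq, M, L, P, 2] normalized (x, y) sample points
    # attention_weights:  [N, Lq, M, L, P]
    n, _, m, d = value.shape
    _, lq, _, levels, points, _ = sampling_locations.shape
    value_list = value.split([h * w for h, w in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1  # map [0, 1] coords to grid_sample's [-1, 1]
    sampled = []
    for lvl, (h, w) in enumerate(value_spatial_shapes):
        # [N, H*W, M, D] -> [N*M, D, H, W] so each attention head is a grid_sample batch entry
        value_l = value_list[lvl].flatten(2).transpose(1, 2).reshape(n * m, d, h, w)
        # [N, Lq, M, P, 2] -> [N*M, Lq, P, 2]
        grid_l = sampling_grids[:, :, :, lvl].transpose(1, 2).flatten(0, 1)
        sampled.append(
            F.grid_sample(value_l, grid_l, mode="bilinear", padding_mode="zeros", align_corners=False)
        )
    # weight the bilinearly sampled values and sum over levels * points
    attention_weights = attention_weights.transpose(1, 2).reshape(n * m, 1, lq, levels * points)
    out = (torch.stack(sampled, dim=-2).flatten(-2) * attention_weights).sum(-1)
    return out.view(n, m * d, lq).transpose(1, 2).contiguous()  # [N, Lq, M*D]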
…andards
Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
…ultidimensional tensor indexing in SetCriterion (ia_bce / varifocal) in lwdetr.py
Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
- Fix PyTorch 2.9 deprecation: use tuple indexing instead of list indexing for multidimensional tensor indexing in SetCriterion (ia_bce / varifocal); illustrated below
- Add end-to-end Lightning training-loop smoke tests for DETR, Deformable DETR, and RF-DETR (train + val + predict on CPU with a timm_resnet18 backbone and synthetic data)
Signed-off-by: josiah.smith <josiahsmithphd@gmail.com>
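A minimal standalone illustration of the deprecation fix (not the actual SetCriterion code):

import torch

logits = torch.randn(2, 10, 5)        # [batch, queries, classes]
batch_idx = torch.tensor([0, 1, 1])
query_idx = torch.tensor([3, 0, 7])

# Deprecated (warns on recent PyTorch): a Python *list* of index tensors
# selected = logits[[batch_idx, query_idx]]

# Fixed: a *tuple* of index tensors, equivalent to logits[batch_idx, query_idx]
selected = logits[(batch_idx, query_idx)]  # shape [3, 5]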
Here's a test script to compare parity with the original DETR implementations.
"""Reference parity tests: prove numerical identity with original DETR code.
These tests verify our adapted DETR, Deformable DETR, and RF-DETR code
produces identical outputs to the original reference implementations.
Tests skip gracefully when reference code is absent. To set up the
reference repos, clone them into /private/tmp/ (or update the path
constants below):
# DETR (Facebook Research)
git clone https://github.com/facebookresearch/detr.git /private/tmp/detr-ref
# Deformable DETR (Fundamentalvision)
git clone https://github.com/fundamentalvision/Deformable-DETR.git /private/tmp/deformable-detr-ref
# RF-DETR (Roboflow) — only the src/ subtree is needed
git clone https://github.com/roboflow/rf-detr.git /private/tmp/rf-detr-ref
# reference path points to /private/tmp/rf-detr-ref/src
Note: Deformable DETR tests also require the CUDA MSDeformAttn extension to
be compiled and importable (will not run on CPU-only machines).
"""
from __future__ import annotations
import sys
import types
from pathlib import Path
import pytest
import torch
from torch import nn
from terratorch.models.detr.detr import DETR, PostProcess, SetCriterion
from terratorch.models.detr.matcher import HungarianMatcher
from terratorch.models.detr.position_encoding import PositionEmbeddingSine
from terratorch.models.detr.transformer import Transformer
# Deformable DETR requires the CUDA MSDeformAttn extension at import time.
try:
from terratorch.models.detr.deformable_detr import PostProcess as DeformablePostProcess
from terratorch.models.detr.deformable_detr import SetCriterion as DeformableSetCriterion
_adapted_deformable_available = True
except ImportError:
_adapted_deformable_available = False
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_REF_DETR_PATH = "/private/tmp/detr-ref"
_REF_DEFORMABLE_DETR_PATH = "/private/tmp/deformable-detr-ref"
ATOL = 1e-6
_ref_detr_available = Path(_REF_DETR_PATH).is_dir()
_ref_deformable_detr_available = Path(_REF_DEFORMABLE_DETR_PATH).is_dir()
requires_ref_detr = pytest.mark.skipif(
not _ref_detr_available,
reason=f"Reference DETR not found at {_REF_DETR_PATH}",
)
requires_ref_deformable_detr = pytest.mark.skipif(
not (_ref_deformable_detr_available and _adapted_deformable_available),
reason="Reference Deformable DETR or CUDA MSDeformAttn extension not available",
)
# ---------------------------------------------------------------------------
# Reference module loading with sys.path isolation
# ---------------------------------------------------------------------------
def _cleanup_ref_modules():
"""Remove reference ``models`` and ``util`` packages from sys.modules."""
for key in list(sys.modules):
if key == "models" or key.startswith("models.") or key == "util" or key.startswith("util."):
del sys.modules[key]
_cached_ref_detr: dict | None = None
def _load_ref_detr():
"""Import reference DETR classes. Result is cached after first call."""
global _cached_ref_detr # noqa: PLW0603
if _cached_ref_detr is not None:
return _cached_ref_detr
_cleanup_ref_modules()
sys.path.insert(0, _REF_DETR_PATH)
try:
from models.detr import DETR as RefDETR # noqa: N811, PLC0415
from models.detr import PostProcess as RefPostProcess # noqa: PLC0415
from models.detr import SetCriterion as RefSetCriterion # noqa: PLC0415
from models.matcher import HungarianMatcher as RefMatcher # noqa: PLC0415
from models.position_encoding import PositionEmbeddingSine as RefPosSine # noqa: PLC0415
from models.transformer import Transformer as RefTransformer # noqa: PLC0415
from util.misc import NestedTensor # noqa: PLC0415
_cached_ref_detr = {
"DETR": RefDETR,
"SetCriterion": RefSetCriterion,
"PostProcess": RefPostProcess,
"Matcher": RefMatcher,
"Transformer": RefTransformer,
"PositionEmbeddingSine": RefPosSine,
"NestedTensor": NestedTensor,
}
return _cached_ref_detr
finally:
sys.path.remove(_REF_DETR_PATH)
_cleanup_ref_modules()
_cached_ref_deformable: dict | None = None
def _load_ref_deformable_detr():
"""Import reference Deformable DETR classes. Result is cached after first call."""
global _cached_ref_deformable # noqa: PLW0603
if _cached_ref_deformable is not None:
return _cached_ref_deformable
_cleanup_ref_modules()
sys.path.insert(0, _REF_DEFORMABLE_DETR_PATH)
try:
# Mock the CUDA extension (MSDeformAttn) so deformable_transformer.py
# can be imported without building the native extension.
mock_ops = types.ModuleType("models.ops")
mock_ops_modules = types.ModuleType("models.ops.modules")
mock_ops_modules.MSDeformAttn = type("MSDeformAttn", (nn.Module,), {})
mock_ops.modules = mock_ops_modules
sys.modules["models.ops"] = mock_ops
sys.modules["models.ops.modules"] = mock_ops_modules
from models.deformable_detr import PostProcess as RefDefPostProcess # noqa: PLC0415
from models.deformable_detr import SetCriterion as RefDefSetCriterion # noqa: PLC0415
_cached_ref_deformable = {
"SetCriterion": RefDefSetCriterion,
"PostProcess": RefDefPostProcess,
}
return _cached_ref_deformable
except Exception:
return None
finally:
sys.path.remove(_REF_DEFORMABLE_DETR_PATH)
_cleanup_ref_modules()
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def ref_detr():
return _load_ref_detr()
@pytest.fixture()
def ref_deformable_detr():
result = _load_ref_deformable_detr()
if result is None:
pytest.skip("Could not import reference Deformable DETR modules")
return result
# ---------------------------------------------------------------------------
# Mock backbones
# ---------------------------------------------------------------------------
class _SimpleBackbone(nn.Module):
"""Parameterless backbone stub exposing only ``num_channels``."""
def __init__(self, num_channels):
super().__init__()
self.num_channels = num_channels
class _MockRefBackbone(nn.Module):
"""Mock backbone that returns fixed features for the reference DETR.
The reference DETR calls ``self.backbone(samples)`` which must return
``([NestedTensor(src, mask)], [pos])``. We ignore the input samples
and return the pre-stored tensors so that the same ``(src, mask, pos)``
flows through both adapted and reference models.
"""
def __init__(self, num_channels, src, mask, pos, nested_tensor_cls):
super().__init__()
self.num_channels = num_channels
self._src = src
self._mask = mask
self._pos = pos
self._nt_cls = nested_tensor_cls
def forward(self, _samples):
return [self._nt_cls(self._src, self._mask)], [self._pos]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _dummy_targets(batch_size, num_classes):
"""Create deterministic dummy targets (caller must seed the RNG)."""
targets = []
for i in range(batch_size):
n = i + 1 # 1 object in first image, 2 in second, etc.
targets.append(
{
"labels": torch.randint(0, num_classes, (n,)),
"boxes": torch.rand(n, 4).clamp(0.1, 0.9),
}
)
return targets
# ===================================================================
# 1. Position Encoding
# ===================================================================
@requires_ref_detr
class TestPositionEncodingParity:
def test_sine_encoding_matches_reference(self, ref_detr):
"""Stateless sine PE: adapted pe(x) vs reference pe(NestedTensor)."""
torch.manual_seed(42)
adapted_pe = PositionEmbeddingSine(64, normalize=True)
ref_pe = ref_detr["PositionEmbeddingSine"](64, normalize=True)
x = torch.randn(2, 128, 16, 16)
mask = torch.zeros(2, 16, 16, dtype=torch.bool)
adapted_out = adapted_pe(x) # creates all-False mask internally
ref_out = ref_pe(ref_detr["NestedTensor"](x, mask))
assert torch.allclose(adapted_out, ref_out, atol=ATOL), (
f"Max diff: {(adapted_out - ref_out).abs().max().item()}"
)
# ===================================================================
# 2. Matcher
# ===================================================================
@requires_ref_detr
class TestMatcherParity:
def test_hungarian_matcher_identical_indices(self, ref_detr):
"""Same cost coefficients, same inputs → same matching indices."""
torch.manual_seed(42)
adapted = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
ref = ref_detr["Matcher"](cost_class=1, cost_bbox=5, cost_giou=2)
outputs = {
"pred_logits": torch.randn(2, 10, 6),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
}
targets = [
{"labels": torch.tensor([0]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]])},
{
"labels": torch.tensor([1, 3]),
"boxes": torch.tensor([[0.3, 0.3, 0.1, 0.1], [0.7, 0.7, 0.3, 0.3]]),
},
]
adapted_idx = adapted(outputs, targets)
ref_idx = ref(outputs, targets)
for (ai, aj), (ri, rj) in zip(adapted_idx, ref_idx, strict=True):
assert torch.equal(ai, ri), f"Pred idx mismatch: {ai} vs {ri}"
assert torch.equal(aj, rj), f"Tgt idx mismatch: {aj} vs {rj}"
# ===================================================================
# 3. Transformer
# ===================================================================
@requires_ref_detr
class TestTransformerParity:
@staticmethod
def _build_pair(ref_detr, return_intermediate):
d_model, nhead = 64, 4
adapted = Transformer(
d_model=d_model,
nhead=nhead,
num_encoder_layers=2,
num_decoder_layers=2,
dim_feedforward=128,
dropout=0.0,
return_intermediate_dec=return_intermediate,
)
ref = ref_detr["Transformer"](
d_model=d_model,
nhead=nhead,
num_encoder_layers=2,
num_decoder_layers=2,
dim_feedforward=128,
dropout=0.0,
return_intermediate_dec=return_intermediate,
)
ref.load_state_dict(adapted.state_dict())
adapted.eval()
ref.eval()
return adapted, ref
def test_no_intermediate_output_matches(self, ref_detr):
"""return_intermediate_dec=False: hs [1,B,Q,D] and memory match."""
torch.manual_seed(42)
adapted, ref = self._build_pair(ref_detr, return_intermediate=False)
src = torch.randn(2, 64, 8, 8)
mask = torch.zeros(2, 8, 8, dtype=torch.bool)
query_embed = torch.randn(10, 64)
pos_embed = torch.randn(2, 64, 8, 8)
a_hs, a_mem = adapted(src, mask, query_embed, pos_embed)
r_hs, r_mem = ref(src, mask, query_embed, pos_embed)
assert torch.allclose(a_hs, r_hs, atol=ATOL), f"hs diff: {(a_hs - r_hs).abs().max()}"
assert torch.allclose(a_mem, r_mem, atol=ATOL), f"mem diff: {(a_mem - r_mem).abs().max()}"
def test_intermediate_output_matches(self, ref_detr):
"""return_intermediate_dec=True: all decoder layer outputs match."""
torch.manual_seed(42)
adapted, ref = self._build_pair(ref_detr, return_intermediate=True)
src = torch.randn(2, 64, 8, 8)
mask = torch.zeros(2, 8, 8, dtype=torch.bool)
query_embed = torch.randn(10, 64)
pos_embed = torch.randn(2, 64, 8, 8)
a_hs, a_mem = adapted(src, mask, query_embed, pos_embed)
r_hs, r_mem = ref(src, mask, query_embed, pos_embed)
assert torch.allclose(a_hs, r_hs, atol=ATOL), f"hs diff: {(a_hs - r_hs).abs().max()}"
assert torch.allclose(a_mem, r_mem, atol=ATOL), f"mem diff: {(a_mem - r_mem).abs().max()}"
# ===================================================================
# 4. SetCriterion (DETR)
# ===================================================================
@requires_ref_detr
class TestSetCriterionParity:
@staticmethod
def _build_pair(ref_detr, num_classes=5):
# Use the SAME matcher instance so matching indices are identical.
matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
weight_dict = {"loss_ce": 1, "loss_bbox": 5, "loss_giou": 2}
losses = ["labels", "boxes", "cardinality"]
adapted = SetCriterion(num_classes, matcher, weight_dict, eos_coef=0.1, losses=losses)
ref = ref_detr["SetCriterion"](num_classes, matcher, weight_dict, eos_coef=0.1, losses=losses)
# Copy the empty_weight buffer so both have identical class weights.
ref.load_state_dict(adapted.state_dict())
return adapted, ref
def test_loss_values_match(self, ref_detr):
"""All loss scalars match for a basic forward pass."""
torch.manual_seed(42)
adapted, ref = self._build_pair(ref_detr)
outputs = {
"pred_logits": torch.randn(2, 10, 6),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
}
targets = _dummy_targets(2, 5)
a_losses = adapted(outputs, targets)
r_losses = ref(outputs, targets)
for key in a_losses:
assert key in r_losses, f"Missing key: {key}"
assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
f"{key}: {a_losses[key].item():.8f} vs {r_losses[key].item():.8f}"
)
def test_loss_with_aux_outputs_match(self, ref_detr):
"""Auxiliary decoder losses (loss_ce_0, loss_bbox_0, ...) also match."""
torch.manual_seed(42)
adapted, ref = self._build_pair(ref_detr)
outputs = {
"pred_logits": torch.randn(2, 10, 6),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
"aux_outputs": [
{
"pred_logits": torch.randn(2, 10, 6),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
},
{
"pred_logits": torch.randn(2, 10, 6),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
},
],
}
targets = _dummy_targets(2, 5)
a_losses = adapted(outputs, targets)
r_losses = ref(outputs, targets)
for key in a_losses:
assert key in r_losses, f"Missing key: {key}"
assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
f"{key}: {a_losses[key].item():.8f} vs {r_losses[key].item():.8f}"
)
# ===================================================================
# 5. PostProcess (DETR)
# ===================================================================
@requires_ref_detr
class TestPostProcessParity:
def test_postprocess_output_matches(self, ref_detr):
"""Stateless post-processor: scores, labels, boxes all match."""
torch.manual_seed(42)
adapted = PostProcess()
ref = ref_detr["PostProcess"]()
outputs = {
"pred_logits": torch.randn(2, 10, 6),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
}
target_sizes = torch.tensor([[480, 640], [320, 480]])
a_results = adapted(outputs, target_sizes)
r_results = ref(outputs, target_sizes)
for a, r in zip(a_results, r_results, strict=True):
assert torch.allclose(a["scores"], r["scores"], atol=ATOL)
assert torch.equal(a["labels"], r["labels"])
assert torch.allclose(a["boxes"], r["boxes"], atol=ATOL)
# ===================================================================
# 6. DETR Model (inner class, bypassing backbone)
# ===================================================================
@requires_ref_detr
class TestDETRModelParity:
@staticmethod
def _build_pair(ref_detr, *, aux_loss, src, mask, pos):
"""Build adapted and reference DETR with shared weights."""
d_model, nhead = 64, 4
num_classes, num_queries = 5, 10
num_channels = src.shape[1]
nested_tensor_cls = ref_detr["NestedTensor"]
adapted_model = DETR(
_SimpleBackbone(num_channels),
Transformer(
d_model=d_model,
nhead=nhead,
num_encoder_layers=2,
num_decoder_layers=2,
dim_feedforward=128,
dropout=0.0,
),
num_classes=num_classes,
num_queries=num_queries,
aux_loss=aux_loss,
)
ref_model = ref_detr["DETR"](
_MockRefBackbone(num_channels, src, mask, pos, nested_tensor_cls),
ref_detr["Transformer"](
d_model=d_model,
nhead=nhead,
num_encoder_layers=2,
num_decoder_layers=2,
dim_feedforward=128,
dropout=0.0,
),
num_classes=num_classes,
num_queries=num_queries,
aux_loss=aux_loss,
)
# All non-backbone param names are identical → strict load succeeds.
ref_model.load_state_dict(adapted_model.state_dict())
adapted_model.eval()
ref_model.eval()
return adapted_model, ref_model, nested_tensor_cls
def test_detr_forward_no_aux_loss(self, ref_detr):
"""pred_logits and pred_boxes match without auxiliary loss."""
torch.manual_seed(42)
src = torch.randn(2, 32, 8, 8)
mask = torch.zeros(2, 8, 8, dtype=torch.bool)
pos = torch.randn(2, 64, 8, 8)
adapted, ref, nt_cls = self._build_pair(
ref_detr,
aux_loss=False,
src=src,
mask=mask,
pos=pos,
)
a_out = adapted(src, mask, pos)
# The NestedTensor content doesn't matter — mock backbone ignores it.
r_out = ref(nt_cls(torch.randn(2, 3, 8, 8), mask))
assert torch.allclose(a_out["pred_logits"], r_out["pred_logits"], atol=ATOL), (
f"logits diff: {(a_out['pred_logits'] - r_out['pred_logits']).abs().max()}"
)
assert torch.allclose(a_out["pred_boxes"], r_out["pred_boxes"], atol=ATOL), (
f"boxes diff: {(a_out['pred_boxes'] - r_out['pred_boxes']).abs().max()}"
)
assert "aux_outputs" not in a_out
assert "aux_outputs" not in r_out
def test_detr_forward_with_aux_loss(self, ref_detr):
"""pred_logits, pred_boxes, and all aux_outputs match."""
torch.manual_seed(42)
src = torch.randn(2, 32, 8, 8)
mask = torch.zeros(2, 8, 8, dtype=torch.bool)
pos = torch.randn(2, 64, 8, 8)
adapted, ref, nt_cls = self._build_pair(
ref_detr,
aux_loss=True,
src=src,
mask=mask,
pos=pos,
)
a_out = adapted(src, mask, pos)
r_out = ref(nt_cls(torch.randn(2, 3, 8, 8), mask))
assert torch.allclose(a_out["pred_logits"], r_out["pred_logits"], atol=ATOL)
assert torch.allclose(a_out["pred_boxes"], r_out["pred_boxes"], atol=ATOL)
assert len(a_out["aux_outputs"]) == len(r_out["aux_outputs"])
for i, (a, r) in enumerate(zip(a_out["aux_outputs"], r_out["aux_outputs"], strict=True)):
assert torch.allclose(a["pred_logits"], r["pred_logits"], atol=ATOL), (
f"aux[{i}] logits diff: {(a['pred_logits'] - r['pred_logits']).abs().max()}"
)
assert torch.allclose(a["pred_boxes"], r["pred_boxes"], atol=ATOL), (
f"aux[{i}] boxes diff: {(a['pred_boxes'] - r['pred_boxes']).abs().max()}"
)
# ===================================================================
# 7. Deformable DETR — SetCriterion and PostProcess
# ===================================================================
@requires_ref_deformable_detr
class TestDeformableSetCriterionParity:
def test_loss_values_match(self, ref_deformable_detr):
"""Focal-loss-based SetCriterion produces identical loss values."""
torch.manual_seed(42)
num_classes = 5
# Use the SAME matcher for both so matching indices are identical.
matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
weight_dict = {"loss_ce": 2, "loss_bbox": 5, "loss_giou": 2}
losses = ["labels", "boxes", "cardinality"]
adapted = DeformableSetCriterion(
num_classes,
matcher,
weight_dict,
losses=losses,
focal_alpha=0.25,
)
ref = ref_deformable_detr["SetCriterion"](
num_classes,
matcher,
weight_dict,
losses=losses,
focal_alpha=0.25,
)
outputs = {
"pred_logits": torch.randn(2, 10, num_classes),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
}
targets = _dummy_targets(2, num_classes)
a_losses = adapted(outputs, targets)
r_losses = ref(outputs, targets)
for key in a_losses:
assert key in r_losses, f"Missing key: {key}"
assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
f"{key}: {a_losses[key].item():.8f} vs {r_losses[key].item():.8f}"
)
@requires_ref_deformable_detr
class TestDeformablePostProcessParity:
def test_postprocess_output_matches(self, ref_deformable_detr):
"""Top-k sigmoid post-processor: scores, labels, boxes all match."""
torch.manual_seed(42)
adapted = DeformablePostProcess()
ref = ref_deformable_detr["PostProcess"]()
outputs = {
"pred_logits": torch.randn(2, 10, 5),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
}
target_sizes = torch.tensor([[480, 640], [320, 480]])
a_results = adapted(outputs, target_sizes)
r_results = ref(outputs, target_sizes)
for a, r in zip(a_results, r_results, strict=True):
assert torch.allclose(a["scores"], r["scores"], atol=ATOL)
assert torch.equal(a["labels"], r["labels"])
assert torch.allclose(a["boxes"], r["boxes"], atol=ATOL)
# ===================================================================
# 8. RF-DETR Parity Tests
# ===================================================================
_REF_RFDETR_PATH = "/private/tmp/rf-detr-ref/src"
_ref_rfdetr_available = Path(_REF_RFDETR_PATH).is_dir()
requires_ref_rfdetr = pytest.mark.skipif(
not _ref_rfdetr_available,
reason=f"Reference RF-DETR not found at {_REF_RFDETR_PATH}",
)
_cached_ref_rfdetr: dict | None = None
# RF-DETR reference modules to clean up after loading
_RFDETR_MODULE_PREFIXES = ("rfdetr",)
def _cleanup_rfdetr_modules():
    """Remove reference rfdetr modules from sys.modules."""
    for key in list(sys.modules):
        if any(key == prefix or key.startswith(prefix + ".") for prefix in _RFDETR_MODULE_PREFIXES):
            del sys.modules[key]
def _load_ref_rfdetr():
"""Import reference RF-DETR classes. Result is cached after first call.
RF-DETR's reference code has heavy transitive dependencies (transformers,
peft, supervision, DINOv2 backbone). We mock these and manually set up
the package hierarchy to only load the model/loss/matcher files we need.
"""
    import importlib.util  # noqa: PLC0415 (needed for spec_from_file_location below)
import logging # noqa: PLC0415
from unittest.mock import MagicMock # noqa: PLC0415
global _cached_ref_rfdetr # noqa: PLW0603
if _cached_ref_rfdetr is not None:
return _cached_ref_rfdetr
_cleanup_rfdetr_modules()
ref_path = _REF_RFDETR_PATH
sys.path.insert(0, ref_path)
try:
# Pre-create rfdetr package hierarchy (prevents __init__.py from running)
for pkg, sub_path in [
("rfdetr", "rfdetr"),
("rfdetr.util", "rfdetr/util"),
("rfdetr.models", "rfdetr/models"),
("rfdetr.models.ops", "rfdetr/models/ops"),
("rfdetr.models.ops.functions", "rfdetr/models/ops/functions"),
("rfdetr.models.ops.modules", "rfdetr/models/ops/modules"),
]:
mod = types.ModuleType(pkg)
mod.__path__ = [ref_path + "/" + sub_path]
mod.__package__ = pkg
sys.modules[pkg] = mod
# Mock rfdetr.util.logger with a real get_logger
mock_logger = types.ModuleType("rfdetr.util.logger")
mock_logger.get_logger = lambda name="rfdetr": logging.getLogger(name)
sys.modules["rfdetr.util.logger"] = mock_logger
# Mock backbone entirely (requires transformers/peft/DINOv2)
sys.modules["rfdetr.models.backbone"] = MagicMock(__path__=[])
# Load ops functions
for fpath, mod_name, parent_name, attr_name in [
(
"rfdetr/models/ops/functions/ms_deform_attn_func.py",
"rfdetr.models.ops.functions.ms_deform_attn_func",
"rfdetr.models.ops.functions",
"ms_deform_attn_core_pytorch",
),
(
"rfdetr/models/ops/modules/ms_deform_attn.py",
"rfdetr.models.ops.modules.ms_deform_attn",
"rfdetr.models.ops.modules",
"MSDeformAttn",
),
]:
spec = importlib.util.spec_from_file_location(mod_name, ref_path + "/" + fpath)
mod = importlib.util.module_from_spec(spec)
sys.modules[mod_name] = mod
spec.loader.exec_module(mod)
setattr(sys.modules[parent_name], attr_name, getattr(mod, attr_name))
# Load util modules
for fname, mod_name in [
("rfdetr/util/box_ops.py", "rfdetr.util.box_ops"),
("rfdetr/util/misc.py", "rfdetr.util.misc"),
]:
spec = importlib.util.spec_from_file_location(mod_name, ref_path + "/" + fname)
mod = importlib.util.module_from_spec(spec)
sys.modules[mod_name] = mod
spec.loader.exec_module(mod)
# Now import the specific classes
from rfdetr.models.lwdetr import LWDETR as RefRFLWDETR # noqa: PLC0415, N811
from rfdetr.models.lwdetr import PostProcess as RefRFPostProcess # noqa: PLC0415
from rfdetr.models.lwdetr import SetCriterion as RefRFSetCriterion # noqa: PLC0415
from rfdetr.models.matcher import HungarianMatcher as RefRFMatcher # noqa: PLC0415
from rfdetr.models.position_encoding import PositionEmbeddingSine as RefRFPosSine # noqa: PLC0415
from rfdetr.models.segmentation_head import SegmentationHead as RefRFSegHead # noqa: PLC0415
from rfdetr.models.transformer import Transformer as RefRFTransformer # noqa: PLC0415
from rfdetr.models.transformer import gen_encoder_output_proposals as ref_gen_proposals # noqa: PLC0415
from rfdetr.util.misc import NestedTensor as RefRFNestedTensor # noqa: PLC0415
_cached_ref_rfdetr = {
"SetCriterion": RefRFSetCriterion,
"PostProcess": RefRFPostProcess,
"Matcher": RefRFMatcher,
"PositionEmbeddingSine": RefRFPosSine,
"SegmentationHead": RefRFSegHead,
"NestedTensor": RefRFNestedTensor,
"Transformer": RefRFTransformer,
"LWDETR": RefRFLWDETR,
"gen_encoder_output_proposals": ref_gen_proposals,
}
return _cached_ref_rfdetr
except Exception:
return None
finally:
if ref_path in sys.path:
sys.path.remove(ref_path)
_cleanup_rfdetr_modules()
@pytest.fixture()
def ref_rfdetr():
result = _load_ref_rfdetr()
if result is None:
pytest.skip("Could not import reference RF-DETR modules")
return result
# -------------------------------------------------------------------
# 8a. RF-DETR Position Encoding Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRPositionEncodingParity:
def test_sine_encoding_matches_reference(self, ref_rfdetr):
"""adapted pe(x) vs reference pe(NestedTensor)."""
from terratorch.models.detr.rfdetr.position_encoding import PositionEmbeddingSine # noqa: PLC0415
torch.manual_seed(42)
adapted_pe = PositionEmbeddingSine(64, normalize=True)
ref_pe = ref_rfdetr["PositionEmbeddingSine"](64, normalize=True)
x = torch.randn(2, 128, 16, 16)
mask = torch.zeros(2, 16, 16, dtype=torch.bool)
adapted_out = adapted_pe(x)
ref_out = ref_pe(ref_rfdetr["NestedTensor"](x, mask), align_dim_orders=False)
assert torch.allclose(adapted_out, ref_out, atol=ATOL), (
f"Max diff: {(adapted_out - ref_out).abs().max().item()}"
)
# -------------------------------------------------------------------
# 8b. RF-DETR Matcher Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRMatcherParity:
def test_hungarian_matcher_identical_indices(self, ref_rfdetr):
"""Same cost coefficients, same inputs -> same matching indices."""
from terratorch.models.detr.rfdetr.matcher import HungarianMatcher # noqa: PLC0415
torch.manual_seed(42)
adapted = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
ref = ref_rfdetr["Matcher"](cost_class=1, cost_bbox=5, cost_giou=2)
outputs = {
"pred_logits": torch.randn(2, 10, 6),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
}
targets = [
{"labels": torch.tensor([0]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]])},
{
"labels": torch.tensor([1, 3]),
"boxes": torch.tensor([[0.3, 0.3, 0.1, 0.1], [0.7, 0.7, 0.3, 0.3]]),
},
]
adapted_idx = adapted(outputs, targets)
ref_idx = ref(outputs, targets)
for (ai, aj), (ri, rj) in zip(adapted_idx, ref_idx, strict=True):
assert torch.equal(ai, ri), f"Pred idx mismatch: {ai} vs {ri}"
assert torch.equal(aj, rj), f"Tgt idx mismatch: {aj} vs {rj}"
# -------------------------------------------------------------------
# 8c. RF-DETR SetCriterion Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRSetCriterionParity:
def test_loss_values_match(self, ref_rfdetr):
"""Focal-loss-based SetCriterion produces identical loss values."""
from terratorch.models.detr.rfdetr.lwdetr import SetCriterion # noqa: PLC0415
from terratorch.models.detr.rfdetr.matcher import HungarianMatcher # noqa: PLC0415
torch.manual_seed(42)
num_classes = 5
matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
weight_dict = {"loss_ce": 2, "loss_bbox": 5, "loss_giou": 2}
losses = ["labels", "boxes", "cardinality"]
adapted = SetCriterion(
num_classes,
matcher,
weight_dict,
focal_alpha=0.25,
losses=losses,
)
ref = ref_rfdetr["SetCriterion"](
num_classes,
matcher,
weight_dict,
focal_alpha=0.25,
losses=losses,
)
outputs = {
"pred_logits": torch.randn(2, 10, num_classes),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9).sigmoid(),
}
targets = _dummy_targets(2, num_classes)
a_losses = adapted(outputs, targets)
r_losses = ref(outputs, targets)
for key in a_losses:
assert key in r_losses, f"Missing key: {key}"
assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
f"{key}: {a_losses[key].item():.8f} vs {r_losses[key].item():.8f}"
)
# -------------------------------------------------------------------
# 8d. RF-DETR PostProcess Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRPostProcessParity:
def test_postprocess_output_matches(self, ref_rfdetr):
"""Top-k sigmoid post-processor: scores, labels, boxes all match."""
from terratorch.models.detr.rfdetr.lwdetr import PostProcess # noqa: PLC0415
torch.manual_seed(42)
adapted = PostProcess(num_select=10)
ref = ref_rfdetr["PostProcess"](num_select=10)
outputs = {
"pred_logits": torch.randn(2, 20, 5),
"pred_boxes": torch.rand(2, 20, 4).clamp(0.1, 0.9),
}
target_sizes = torch.tensor([[480, 640], [320, 480]])
a_results = adapted(outputs, target_sizes)
r_results = ref(outputs, target_sizes)
for a, r in zip(a_results, r_results, strict=True):
assert torch.allclose(a["scores"], r["scores"], atol=ATOL)
assert torch.equal(a["labels"], r["labels"])
assert torch.allclose(a["boxes"], r["boxes"], atol=ATOL)
# -------------------------------------------------------------------
# 8e. RF-DETR Segmentation Head Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRSegmentationHeadParity:
def test_segmentation_head_forward_matches(self, ref_rfdetr):
"""Segmentation head forward with identical weights -> identical output."""
from terratorch.models.detr.rfdetr.segmentation_head import SegmentationHead # noqa: PLC0415
torch.manual_seed(42)
adapted = SegmentationHead(in_dim=64, num_blocks=2)
ref = ref_rfdetr["SegmentationHead"](in_dim=64, num_blocks=2)
ref.load_state_dict(adapted.state_dict())
adapted.eval()
ref.eval()
spatial_features = torch.randn(2, 64, 16, 16)
query_features = [torch.randn(2, 5, 64), torch.randn(2, 5, 64)]
image_size = (64, 64)
a_out = adapted(spatial_features, query_features, image_size)
r_out = ref(spatial_features, query_features, image_size)
assert len(a_out) == len(r_out)
for a, r in zip(a_out, r_out, strict=True):
assert torch.allclose(a, r, atol=ATOL), f"Max diff: {(a - r).abs().max().item()}"
# -------------------------------------------------------------------
# 8f. RF-DETR gen_encoder_output_proposals Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRGenProposalsParity:
"""Verify gen_encoder_output_proposals matches reference for both unsigmoid modes."""
def test_proposals_with_mask(self, ref_rfdetr):
"""Proposals with padding mask: both unsigmoid=True and False."""
from terratorch.models.detr.rfdetr.transformer import gen_encoder_output_proposals # noqa: PLC0415
torch.manual_seed(42)
bs, d_model = 2, 64
spatial_shapes = torch.tensor([[8, 8], [4, 4]], dtype=torch.long)
total_len = 8 * 8 + 4 * 4 # 80
memory = torch.randn(bs, total_len, d_model)
mask = torch.zeros(bs, total_len, dtype=torch.bool)
mask[:, 60:] = True # mask some trailing positions
for unsigmoid in [True, False]:
a_mem, a_prop = gen_encoder_output_proposals(memory, mask, spatial_shapes, unsigmoid=unsigmoid)
r_mem, r_prop = ref_rfdetr["gen_encoder_output_proposals"](
memory, mask, spatial_shapes, unsigmoid=unsigmoid
)
assert torch.allclose(a_mem, r_mem, atol=ATOL), (
f"Memory mismatch (unsigmoid={unsigmoid}): {(a_mem - r_mem).abs().max().item()}"
)
assert torch.allclose(a_prop, r_prop, atol=ATOL), (
f"Proposals mismatch (unsigmoid={unsigmoid}): {(a_prop - r_prop).abs().max().item()}"
)
def test_proposals_no_mask(self, ref_rfdetr):
"""Proposals with mask=None (no padding)."""
from terratorch.models.detr.rfdetr.transformer import gen_encoder_output_proposals # noqa: PLC0415
torch.manual_seed(42)
bs, d_model = 2, 64
spatial_shapes = torch.tensor([[8, 8], [4, 4]], dtype=torch.long)
total_len = 8 * 8 + 4 * 4
memory = torch.randn(bs, total_len, d_model)
a_mem, a_prop = gen_encoder_output_proposals(memory, None, spatial_shapes)
r_mem, r_prop = ref_rfdetr["gen_encoder_output_proposals"](memory, None, spatial_shapes)
assert torch.allclose(a_mem, r_mem, atol=ATOL)
assert torch.allclose(a_prop, r_prop, atol=ATOL)
# -------------------------------------------------------------------
# 8g. RF-DETR LWDETR End-to-End Forward Parity
# -------------------------------------------------------------------
class _MockRFDETRBackbone(nn.Module):
"""Mock backbone that returns pre-set features for reference LWDETR."""
def __init__(self, srcs, masks, poss, nested_tensor_cls):
super().__init__()
self._srcs = srcs
self._masks = masks
self._poss = poss
self._nt_cls = nested_tensor_cls
def forward(self, _samples):
features = [self._nt_cls(s, m) for s, m in zip(self._srcs, self._masks, strict=True)]
return features, self._poss
def _compare_outputs(a_out, r_out, atol=ATOL):
"""Recursively compare LWDETR output dicts."""
for key in a_out:
assert key in r_out, f"Missing key in reference output: {key}"
if key == "aux_outputs":
assert len(a_out[key]) == len(r_out[key])
for i, (a_aux, r_aux) in enumerate(zip(a_out[key], r_out[key], strict=True)):
for ak in a_aux:
assert torch.allclose(a_aux[ak], r_aux[ak], atol=atol), (
f"aux_outputs[{i}][{ak}]: max diff {(a_aux[ak] - r_aux[ak]).abs().max().item()}"
)
elif key == "enc_outputs":
for ek in a_out[key]:
assert torch.allclose(a_out[key][ek], r_out[key][ek], atol=atol), (
f"enc_outputs[{ek}]: max diff {(a_out[key][ek] - r_out[key][ek]).abs().max().item()}"
)
else:
assert torch.allclose(a_out[key], r_out[key], atol=atol), (
f"{key}: max diff {(a_out[key] - r_out[key]).abs().max().item()}"
)
@requires_ref_rfdetr
class TestRFDETRLWDETRForwardParity:
"""Verify LWDETR forward produces identical output to reference (covers Transformer,
gen_encoder_output_proposals, bbox_reparam, two-stage, iterative refinement)."""
@staticmethod
def _make_models(ref_rfdetr, *, two_stage, bbox_reparam, aux_loss):
"""Build adapted + reference LWDETR with matching weights and dummy features."""
from terratorch.models.detr.rfdetr.lwdetr import LWDETR # noqa: PLC0415
from terratorch.models.detr.rfdetr.transformer import Transformer # noqa: PLC0415
d_model, num_queries, nhead, num_layers = 64, 10, 4, 2
num_feature_levels, num_classes = 2, 5
adapted_transformer = Transformer(
d_model=d_model,
sa_nhead=nhead,
ca_nhead=nhead,
num_queries=num_queries,
num_decoder_layers=num_layers,
dim_feedforward=128,
dropout=0.0,
return_intermediate_dec=True,
two_stage=two_stage,
num_feature_levels=num_feature_levels,
dec_n_points=4,
bbox_reparam=bbox_reparam,
)
adapted = LWDETR(
adapted_transformer,
segmentation_head=None,
num_classes=num_classes,
num_queries=num_queries,
aux_loss=aux_loss,
two_stage=two_stage,
bbox_reparam=bbox_reparam,
)
# Dummy features
torch.manual_seed(42)
bs = 2
srcs = [torch.randn(bs, d_model, 8, 8), torch.randn(bs, d_model, 4, 4)]
masks = [torch.zeros(bs, 8, 8, dtype=torch.bool), torch.zeros(bs, 4, 4, dtype=torch.bool)]
poss = [torch.randn(bs, d_model, 8, 8), torch.randn(bs, d_model, 4, 4)]
mock_backbone = _MockRFDETRBackbone(srcs, masks, poss, ref_rfdetr["NestedTensor"])
ref_transformer = ref_rfdetr["Transformer"](
d_model=d_model,
sa_nhead=nhead,
ca_nhead=nhead,
num_queries=num_queries,
num_decoder_layers=num_layers,
dim_feedforward=128,
dropout=0.0,
return_intermediate_dec=True,
two_stage=two_stage,
num_feature_levels=num_feature_levels,
dec_n_points=4,
bbox_reparam=bbox_reparam,
)
ref_model = ref_rfdetr["LWDETR"](
mock_backbone,
ref_transformer,
segmentation_head=None,
num_classes=num_classes,
num_queries=num_queries,
aux_loss=aux_loss,
two_stage=two_stage,
bbox_reparam=bbox_reparam,
)
# Copy all weights (mock backbone has no params, so state_dicts should match)
ref_model.load_state_dict(adapted.state_dict(), strict=True)
adapted.eval()
ref_model.eval()
return adapted, ref_model, srcs, masks, poss
def test_forward_no_two_stage(self, ref_rfdetr):
"""LWDETR forward without two-stage: outputs match reference exactly."""
adapted, ref_model, srcs, masks, poss = self._make_models(
ref_rfdetr, two_stage=False, bbox_reparam=False, aux_loss=True
)
with torch.no_grad():
a_out = adapted(srcs, masks, poss)
# Reference forward takes NestedTensor; mock backbone ignores input
dummy_nt = ref_rfdetr["NestedTensor"](torch.randn(2, 3, 64, 64), torch.zeros(2, 64, 64, dtype=torch.bool))
r_out = ref_model(dummy_nt)
_compare_outputs(a_out, r_out)
def test_forward_two_stage_bbox_reparam(self, ref_rfdetr):
"""LWDETR forward with two_stage + bbox_reparam: outputs match reference exactly."""
adapted, ref_model, srcs, masks, poss = self._make_models(
ref_rfdetr, two_stage=True, bbox_reparam=True, aux_loss=True
)
with torch.no_grad():
a_out = adapted(srcs, masks, poss)
dummy_nt = ref_rfdetr["NestedTensor"](torch.randn(2, 3, 64, 64), torch.zeros(2, 64, 64, dtype=torch.bool))
r_out = ref_model(dummy_nt)
_compare_outputs(a_out, r_out)
# Verify two-stage keys are present
assert "enc_outputs" in a_out
assert "pred_logits" in a_out["enc_outputs"]
assert "pred_boxes" in a_out["enc_outputs"]
def test_forward_two_stage_no_bbox_reparam(self, ref_rfdetr):
"""LWDETR forward with two_stage but no bbox_reparam."""
adapted, ref_model, srcs, masks, poss = self._make_models(
ref_rfdetr, two_stage=True, bbox_reparam=False, aux_loss=True
)
with torch.no_grad():
a_out = adapted(srcs, masks, poss)
dummy_nt = ref_rfdetr["NestedTensor"](torch.randn(2, 3, 64, 64), torch.zeros(2, 64, 64, dtype=torch.bool))
r_out = ref_model(dummy_nt)
_compare_outputs(a_out, r_out)
# -------------------------------------------------------------------
# 8h. RF-DETR SetCriterion ia_bce_loss + aux/enc Parity
# -------------------------------------------------------------------
@requires_ref_rfdetr
class TestRFDETRSetCriterionIaBceParity:
"""Verify SetCriterion with ia_bce_loss produces identical losses to reference."""
def test_ia_bce_loss_values_match(self, ref_rfdetr):
"""ia_bce_loss-based SetCriterion: all loss values match reference."""
from terratorch.models.detr.rfdetr.lwdetr import SetCriterion # noqa: PLC0415
from terratorch.models.detr.rfdetr.matcher import HungarianMatcher # noqa: PLC0415
torch.manual_seed(42)
num_classes = 5
matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
weight_dict = {"loss_ce": 2, "loss_bbox": 5, "loss_giou": 2}
losses = ["labels", "boxes", "cardinality"]
adapted = SetCriterion(num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, ia_bce_loss=True)
ref = ref_rfdetr["SetCriterion"](
num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, ia_bce_loss=True
)
adapted.eval()
ref.eval()
outputs = {
"pred_logits": torch.randn(2, 10, num_classes),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
}
targets = _dummy_targets(2, num_classes)
a_losses = adapted(outputs, targets)
r_losses = ref(outputs, targets)
for key in a_losses:
assert key in r_losses, f"Missing key in ref: {key}"
assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
f"{key}: adapted={a_losses[key].item():.8f} vs ref={r_losses[key].item():.8f}"
)
def test_with_aux_and_enc_outputs(self, ref_rfdetr):
"""SetCriterion with ia_bce_loss + aux_outputs + enc_outputs: all losses match."""
from terratorch.models.detr.rfdetr.lwdetr import SetCriterion # noqa: PLC0415
from terratorch.models.detr.rfdetr.matcher import HungarianMatcher # noqa: PLC0415
torch.manual_seed(42)
num_classes = 5
matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
weight_dict = {
"loss_ce": 2,
"loss_bbox": 5,
"loss_giou": 2,
"loss_ce_0": 2,
"loss_bbox_0": 5,
"loss_giou_0": 2,
"loss_ce_enc": 2,
"loss_bbox_enc": 5,
"loss_giou_enc": 2,
}
losses = ["labels", "boxes", "cardinality"]
adapted = SetCriterion(num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, ia_bce_loss=True)
ref = ref_rfdetr["SetCriterion"](
num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, ia_bce_loss=True
)
adapted.eval()
ref.eval()
outputs = {
"pred_logits": torch.randn(2, 10, num_classes),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
"aux_outputs": [
{
"pred_logits": torch.randn(2, 10, num_classes),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
},
],
"enc_outputs": {
"pred_logits": torch.randn(2, 10, num_classes),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
},
}
targets = _dummy_targets(2, num_classes)
a_losses = adapted(outputs, targets)
r_losses = ref(outputs, targets)
for key in a_losses:
assert key in r_losses, f"Missing key in ref: {key}"
assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
f"{key}: adapted={a_losses[key].item():.8f} vs ref={r_losses[key].item():.8f}"
)
# Verify auxiliary and encoder loss keys are present
assert "loss_ce_0" in a_losses
assert "loss_bbox_0" in a_losses
assert "loss_ce_enc" in a_losses
assert "loss_bbox_enc" in a_losses
def test_varifocal_loss_values_match(self, ref_rfdetr):
"""Varifocal loss: all loss values match reference."""
from terratorch.models.detr.rfdetr.lwdetr import SetCriterion # noqa: PLC0415
from terratorch.models.detr.rfdetr.matcher import HungarianMatcher # noqa: PLC0415
torch.manual_seed(42)
num_classes = 5
matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
weight_dict = {"loss_ce": 2, "loss_bbox": 5, "loss_giou": 2}
losses = ["labels", "boxes", "cardinality"]
adapted = SetCriterion(
num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, use_varifocal_loss=True
)
ref = ref_rfdetr["SetCriterion"](
num_classes, matcher, weight_dict, focal_alpha=0.25, losses=losses, use_varifocal_loss=True
)
adapted.eval()
ref.eval()
outputs = {
"pred_logits": torch.randn(2, 10, num_classes),
"pred_boxes": torch.rand(2, 10, 4).clamp(0.1, 0.9),
}
targets = _dummy_targets(2, num_classes)
a_losses = adapted(outputs, targets)
r_losses = ref(outputs, targets)
for key in a_losses:
assert key in r_losses, f"Missing key in ref: {key}"
assert torch.allclose(a_losses[key], r_losses[key], atol=ATOL), (
f"{key}: adapted={a_losses[key].item():.8f} vs ref={r_losses[key].item():.8f}"
) |
Closes #1111