Skip to content

Commit ca402e2

Browse files
authored
[LightGlue] Fixed attribute usage from descriptor_dim to keypoint_detector_descriptor_dim (huggingface#39021)
fix: fix descriptor dimension handling in LightGlue model
1 parent 48b6ef0 commit ca402e2

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

src/transformers/models/lightglue/modeling_lightglue.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -516,16 +516,15 @@ def __init__(self, config: LightGlueConfig):
516516

517517
self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
518518

519+
self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
519520
self.descriptor_dim = config.descriptor_dim
520521
self.num_layers = config.num_hidden_layers
521522
self.filter_threshold = config.filter_threshold
522523
self.depth_confidence = config.depth_confidence
523524
self.width_confidence = config.width_confidence
524525

525-
if self.descriptor_dim != config.keypoint_detector_config.descriptor_decoder_dim:
526-
self.input_projection = nn.Linear(
527-
config.keypoint_detector_config.descriptor_decoder_dim, self.descriptor_dim, bias=True
528-
)
526+
if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
527+
self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
529528
else:
530529
self.input_projection = nn.Identity()
531530

@@ -721,7 +720,7 @@ def _match_image_pair(
721720
# (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
722721
keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
723722
mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
724-
descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.descriptor_dim)
723+
descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
725724
image_indices = torch.arange(batch_size * 2, device=device)
726725
# Keypoint normalization
727726
keypoints = normalize_keypoints(keypoints, height, width)
@@ -892,7 +891,7 @@ def forward(
892891

893892
keypoints, _, descriptors, mask = keypoint_detections[:4]
894893
keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
895-
descriptors = descriptors.reshape(batch_size, 2, -1, self.descriptor_dim).to(pixel_values)
894+
descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values)
896895
mask = mask.reshape(batch_size, 2, -1)
897896

898897
absolute_keypoints = keypoints.clone()

src/transformers/models/lightglue/modular_lightglue.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -587,16 +587,15 @@ def __init__(self, config: LightGlueConfig):
587587

588588
self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
589589

590+
self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
590591
self.descriptor_dim = config.descriptor_dim
591592
self.num_layers = config.num_hidden_layers
592593
self.filter_threshold = config.filter_threshold
593594
self.depth_confidence = config.depth_confidence
594595
self.width_confidence = config.width_confidence
595596

596-
if self.descriptor_dim != config.keypoint_detector_config.descriptor_decoder_dim:
597-
self.input_projection = nn.Linear(
598-
config.keypoint_detector_config.descriptor_decoder_dim, self.descriptor_dim, bias=True
599-
)
597+
if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
598+
self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
600599
else:
601600
self.input_projection = nn.Identity()
602601

@@ -792,7 +791,7 @@ def _match_image_pair(
792791
# (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
793792
keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
794793
mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
795-
descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.descriptor_dim)
794+
descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
796795
image_indices = torch.arange(batch_size * 2, device=device)
797796
# Keypoint normalization
798797
keypoints = normalize_keypoints(keypoints, height, width)
@@ -963,7 +962,7 @@ def forward(
963962

964963
keypoints, _, descriptors, mask = keypoint_detections[:4]
965964
keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
966-
descriptors = descriptors.reshape(batch_size, 2, -1, self.descriptor_dim).to(pixel_values)
965+
descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values)
967966
mask = mask.reshape(batch_size, 2, -1)
968967

969968
absolute_keypoints = keypoints.clone()

0 commit comments

Comments
 (0)