@@ -516,16 +516,15 @@ def __init__(self, config: LightGlueConfig):
516
516
517
517
self .keypoint_detector = AutoModelForKeypointDetection .from_config (config .keypoint_detector_config )
518
518
519
+ self .keypoint_detector_descriptor_dim = config .keypoint_detector_config .descriptor_decoder_dim
519
520
self .descriptor_dim = config .descriptor_dim
520
521
self .num_layers = config .num_hidden_layers
521
522
self .filter_threshold = config .filter_threshold
522
523
self .depth_confidence = config .depth_confidence
523
524
self .width_confidence = config .width_confidence
524
525
525
- if self .descriptor_dim != config .keypoint_detector_config .descriptor_decoder_dim :
526
- self .input_projection = nn .Linear (
527
- config .keypoint_detector_config .descriptor_decoder_dim , self .descriptor_dim , bias = True
528
- )
526
+ if self .descriptor_dim != self .keypoint_detector_descriptor_dim :
527
+ self .input_projection = nn .Linear (self .keypoint_detector_descriptor_dim , self .descriptor_dim , bias = True )
529
528
else :
530
529
self .input_projection = nn .Identity ()
531
530
@@ -721,7 +720,7 @@ def _match_image_pair(
721
720
# (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
722
721
keypoints = keypoints .reshape (batch_size * 2 , initial_num_keypoints , 2 )
723
722
mask = mask .reshape (batch_size * 2 , initial_num_keypoints ) if mask is not None else None
724
- descriptors = descriptors .reshape (batch_size * 2 , initial_num_keypoints , self .descriptor_dim )
723
+ descriptors = descriptors .reshape (batch_size * 2 , initial_num_keypoints , self .keypoint_detector_descriptor_dim )
725
724
image_indices = torch .arange (batch_size * 2 , device = device )
726
725
# Keypoint normalization
727
726
keypoints = normalize_keypoints (keypoints , height , width )
@@ -892,7 +891,7 @@ def forward(
892
891
893
892
keypoints , _ , descriptors , mask = keypoint_detections [:4 ]
894
893
keypoints = keypoints .reshape (batch_size , 2 , - 1 , 2 ).to (pixel_values )
895
- descriptors = descriptors .reshape (batch_size , 2 , - 1 , self .descriptor_dim ).to (pixel_values )
894
+ descriptors = descriptors .reshape (batch_size , 2 , - 1 , self .keypoint_detector_descriptor_dim ).to (pixel_values )
896
895
mask = mask .reshape (batch_size , 2 , - 1 )
897
896
898
897
absolute_keypoints = keypoints .clone ()
0 commit comments