Skip to content

Commit cdbc642

Browse files
committed
Better timm compatibility fix
1 parent 4c55ebd commit cdbc642

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

dmidas/backbones/beit.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,15 @@ def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tenso
9595
"""
9696
Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes.
9797
"""
98+
if hasattr(self, 'drop_path1') and not hasattr(self, 'drop_path'):
99+
self.drop_path = self.drop_path1
98100
if self.gamma_1 is None:
99-
x = x + self.drop_path1(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias))
100-
x = x + self.drop_path2(self.mlp(self.norm2(x)))
101+
x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias))
102+
x = x + self.drop_path(self.mlp(self.norm2(x)))
101103
else:
102-
x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), resolution,
104+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution,
103105
shared_rel_pos_bias=shared_rel_pos_bias))
104-
x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
106+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
105107
return x
106108

107109

dzoedepth/models/base_models/midas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_m
340340
print("img_size", img_size)
341341
# TODO: use locally-bundled midas
342342
# The repo should be changed back to isl-org/MiDaS once this MR lands
343-
midas = torch.hub.load("AyaanShah2204/MiDaS", midas_model_type,
343+
midas = torch.hub.load("semjon00/MiDaS", midas_model_type,
344344
pretrained=use_pretrained_midas, force_reload=force_reload)
345345
kwargs.update({'keep_aspect_ratio': force_keep_ar})
346346
midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features,

0 commit comments

Comments
 (0)