Skip to content

Commit db01d5c

Browse files
committed
Support for Depth Anything
Thanks to huchenlei for his Depth Anything integration code for Mikubill/sd-webui-controlnet. This code was useful for this commit.
1 parent 0389f9e commit db01d5c

File tree

7 files changed

+104
-9
lines changed

7 files changed

+104
-9
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
## Changelog
2+
### 0.4.6
3+
* Support for [Depth Anything](https://github.com/LiheYoung/Depth-Anything).
24
### 0.4.5
3-
* Support for [Marigold](https://marigoldmonodepth.github.io). [PR #385](https://github.com/thygate/stable-diffusion-webui-depthmap-script/pull/385).
5+
* Preliminary support for [Marigold](https://marigoldmonodepth.github.io). [PR #385](https://github.com/thygate/stable-diffusion-webui-depthmap-script/pull/385).
46
### 0.4.4
57
* Compatibility with stable-diffusion-webui 1.6.0
68
### 0.4.3 video processing tab

README.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ Feel free to comment and share in the discussions and submit issues.
9494

9595
## Acknowledgements
9696

97-
This project relies on code and information from following papers :
97+
This project relies on code and information from the following papers :
9898

9999
MiDaS :
100100

@@ -211,3 +211,16 @@ Marigold - Repurposing Diffusion-Based Image Generators for Monocular Depth Esti
211211
primaryClass={cs.CV}
212212
}
213213
```
214+
215+
Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data
216+
217+
```
218+
@misc{yang2024depth,
219+
title={Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data},
220+
author={Lihe Yang and Bingyi Kang and Zilong Huang and Xiaogang Xu and Jiashi Feng and Hengshuang Zhao},
221+
year={2024},
222+
eprint={2401.10891},
223+
archivePrefix={arXiv},
224+
primaryClass={cs.CV}
225+
}
226+
```

install.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
# Installs dependencies
2+
13
import launch
24
import platform
35
import sys
6+
import importlib.metadata
47

58
# TODO: some dependencies apparently being reinstalled on every run. Investigate and fix.
69

@@ -54,3 +57,20 @@ def ensure(module_name, min_version=None):
5457

5558
if platform.system() == 'Darwin':
5659
ensure('pyqt6')
60+
61+
# Depth Anything
62+
def get_installed_version(package: str):
63+
try:
64+
return importlib.metadata.version(package)
65+
except Exception:
66+
return None
67+
def try_install_from_wheel(pkg_name: str, wheel_url: str):
68+
if get_installed_version(pkg_name) is not None:
69+
return
70+
try:
71+
launch.run_pip(f"install {wheel_url}", f" {pkg_name} requirement for depthmap script")
72+
except Exception as e:
73+
print('Failed to install wheel for Depth Anything support. It won\'t work.')
74+
try_install_from_wheel(
75+
"depth_anything",
76+
"https://github.com/huchenlei/Depth-Anything/releases/download/v1.0.0/depth_anything-2024.1.22.0-py2.py3-none-any.whl")

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ networkx>=2.5
1919
diffusers>=0.20.1 # For Marigold
2020
pyqt5; sys_platform == 'windows'
2121
pyqt6; sys_platform != 'windows'
22+
https://github.com/huchenlei/Depth-Anything/releases/download/v1.0.0/depth_anything-2024.1.22.0-py2.py3-none-any.whl

src/common_ui.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def main_ui_panel(is_depth_tab):
3838
'dpt_hybrid_384 (midas 3.0)',
3939
'midas_v21', 'midas_v21_small',
4040
'zoedepth_n (indoor)', 'zoedepth_k (outdoor)', 'zoedepth_nk',
41-
'Marigold v1'],
41+
'Marigold v1', 'depth_anything'],
4242
type="index")
4343
with gr.Box() as cur_option_root:
4444
inp -= 'depthmap_gen_row_1', cur_option_root

src/depthmap_generation.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,15 @@ def ensure_models(self, model_type, device: torch.device, boost: bool):
6565
def load_models(self, model_type, device: torch.device, boost: bool):
6666
"""Ensure that the depth model is loaded"""
6767

68+
# TODO: we need to at least try to find models downloaded by other plugins (e.g. controlnet)
69+
6870
# model path and name
6971
# ZoeDepth and Marigold do not use this
7072
model_dir = "./models/midas"
7173
if model_type == 0:
7274
model_dir = "./models/leres"
75+
if model_type == 11:
76+
model_dir = "./models/depth_anything"
7377

7478
# create paths to model if not present
7579
os.makedirs(model_dir, exist_ok=True)
@@ -202,14 +206,31 @@ def load_models(self, model_type, device: torch.device, boost: bool):
202206
except:
203207
pass # run without xformers
204208

209+
elif model_type == 11: # depth_anything
210+
from depth_anything.dpt import DPT_DINOv2
211+
# This will download the model... to some place
212+
model = (
213+
DPT_DINOv2(
214+
encoder="vitl",
215+
features=256,
216+
out_channels=[256, 512, 1024, 1024],
217+
localhub=False,
218+
).to(device).eval()
219+
)
220+
model_path = f"{model_dir}/depth_anything_vitl14.pth"
221+
ensure_file_downloaded(model_path,
222+
"https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth")
223+
224+
model.load_state_dict(torch.load(model_path))
225+
205226
if model_type in range(0, 10):
206227
model.eval() # prepare for evaluation
207228
# optimize
208229
if device == torch.device("cuda"):
209230
if model_type in [0, 1, 2, 3, 4, 5, 6]:
210231
model = model.to(memory_format=torch.channels_last) # TODO: weird
211232
if not self.no_half:
212-
if model_type in [1, 2, 3, 4, 5, 6] and not boost: # TODO: zoedepth, too?
233+
if model_type in [1, 2, 3, 4, 5, 6] and not boost: # TODO: zoedepth, Marigold and depth_anything, too?
213234
model = model.half()
214235
model.to(device) # to correct device
215236

@@ -250,7 +271,8 @@ def get_default_net_size(model_type):
250271
7: [384, 512],
251272
8: [384, 768],
252273
9: [384, 512],
253-
10: [768, 768]
274+
10: [768, 768],
275+
11: [518, 518]
254276
}
255277
if model_type in sizes:
256278
return sizes[model_type]
@@ -307,6 +329,8 @@ def get_raw_prediction(self, input, net_width, net_height):
307329
elif self.depth_model_type == 10:
308330
raw_prediction = estimatemarigold(img, self.depth_model, net_width, net_height,
309331
self.marigold_ensembles, self.marigold_steps)
332+
elif self.depth_model_type == 11:
333+
raw_prediction = estimatedepthanything(img, self.depth_model, net_width, net_height)
310334
else:
311335
raw_prediction = estimateboost(img, self.depth_model, self.depth_model_type, self.pix2pix_model,
312336
self.boost_rmax)
@@ -414,6 +438,7 @@ def estimatemidas(img, model, w, h, resize_mode, normalization, no_half, precisi
414438
# TODO: "h" is not used
415439
def estimatemarigold(image, model, w, h, marigold_ensembles=5, marigold_steps=12):
416440
# This hideous thing should be re-implemented once there is support from the upstream.
441+
# TODO: re-implement this hideous thing by using features from the upstream
417442
img = cv2.cvtColor((image * 255.0001).astype('uint8'), cv2.COLOR_BGR2RGB)
418443
img = Image.fromarray(img)
419444
with torch.no_grad():
@@ -423,6 +448,37 @@ def estimatemarigold(image, model, w, h, marigold_ensembles=5, marigold_steps=12
423448
return cv2.resize(pipe_out.depth_np, (image.shape[:2][::-1]), interpolation=cv2.INTER_CUBIC)
424449

425450

451+
def estimatedepthanything(image, model, w, h):
452+
from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
453+
transform = Compose(
454+
[
455+
Resize(
456+
width=w // 14 * 14,
457+
height=h // 14 * 14,
458+
resize_target=False,
459+
keep_aspect_ratio=True,
460+
ensure_multiple_of=14,
461+
resize_method="lower_bound",
462+
image_interpolation_method=cv2.INTER_CUBIC,
463+
),
464+
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
465+
PrepareForNet(),
466+
]
467+
)
468+
469+
timage = transform({"image": image})["image"]
470+
timage = torch.from_numpy(timage).unsqueeze(0).to(next(model.parameters()).device)
471+
472+
with torch.no_grad():
473+
depth = model(timage)
474+
import torch.nn.functional as F
475+
depth = F.interpolate(
476+
depth[None], (image.shape[0], image.shape[1]), mode="bilinear", align_corners=False
477+
)[0, 0]
478+
479+
return depth.cpu().numpy()
480+
481+
426482
class ImageandPatchs:
427483
def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
428484
self.root_dir = root_dir
@@ -640,13 +696,14 @@ def estimateboost(img, model, model_type, pix2pixmodel, whole_size_threshold):
640696

641697
if model_type == 0: # leres
642698
net_receptive_field_size = 448
643-
patch_netsize = 2 * net_receptive_field_size
644699
elif model_type == 1: # dpt_beit_large_512
645700
net_receptive_field_size = 512
646-
patch_netsize = 2 * net_receptive_field_size
701+
elif model_type == 11: # depth_anything
702+
net_receptive_field_size = 518
647703
else: # other midas # TODO Marigold support
648704
net_receptive_field_size = 384
649-
patch_netsize = 2 * net_receptive_field_size
705+
patch_netsize = 2 * net_receptive_field_size
706+
# Good luck trying to use zoedepth
650707

651708
gc.collect()
652709
backbone.torch_gc()
@@ -916,6 +973,8 @@ def singleestimate(img, msize, model, net_type):
916973
return estimateleres(img, model, msize, msize)
917974
elif net_type == 10:
918975
return estimatemarigold(img, model, msize, msize)
976+
elif net_type == 11:
977+
return estimatedepthanything(img, model, msize, msize)
919978
elif net_type >= 7:
920979
# np to PIL
921980
return estimatezoedepth(Image.fromarray(np.uint8(img * 255)).convert('RGB'), model, msize, msize)

src/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def get_commit_hash():
1515

1616
REPOSITORY_NAME = "stable-diffusion-webui-depthmap-script"
1717
SCRIPT_NAME = "DepthMap"
18-
SCRIPT_VERSION = "v0.4.5"
18+
SCRIPT_VERSION = "v0.4.6"
1919
SCRIPT_FULL_NAME = f"{SCRIPT_NAME} {SCRIPT_VERSION} ({get_commit_hash()})"
2020

2121

0 commit comments

Comments
 (0)