@@ -65,11 +65,15 @@ def ensure_models(self, model_type, device: torch.device, boost: bool):
     def load_models(self, model_type, device: torch.device, boost: bool):
         """Ensure that the depth model is loaded"""
 
+        # TODO: we need to at least try to find models downloaded by other plugins (e.g. controlnet)
+
         # model path and name
         # ZoeDepth and Marigold do not use this
         model_dir = "./models/midas"
         if model_type == 0:
             model_dir = "./models/leres"
+        if model_type == 11:
+            model_dir = "./models/depth_anything"
 
         # create paths to model if not present
         os.makedirs(model_dir, exist_ok=True)
@@ -202,14 +206,31 @@ def load_models(self, model_type, device: torch.device, boost: bool):
             except:
                 pass  # run without xformers
 
+        elif model_type == 11:  # depth_anything
+            from depth_anything.dpt import DPT_DINOv2
+            # Note: with localhub=False this downloads the DINOv2 backbone via torch.hub into its default cache
+            model = (
+                DPT_DINOv2(
+                    encoder="vitl",
+                    features=256,
+                    out_channels=[256, 512, 1024, 1024],
+                    localhub=False,
+                ).to(device).eval()
+            )
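+            # ViT-L/14 checkpoint, fetched from the LiheYoung/Depth-Anything HF space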
+            model_path = f"{model_dir}/depth_anything_vitl14.pth"
+            ensure_file_downloaded(model_path,
+                                   "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth")
+
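+            # weights stay fp32: the half() conversion below still skips model_type 11 (see the TODO there)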
+            model.load_state_dict(torch.load(model_path))
+
         if model_type in range(0, 10):
             model.eval()  # prepare for evaluation
         # optimize
         if device == torch.device("cuda"):
             if model_type in [0, 1, 2, 3, 4, 5, 6]:
                 model = model.to(memory_format=torch.channels_last)  # TODO: weird
             if not self.no_half:
-                if model_type in [1, 2, 3, 4, 5, 6] and not boost:  # TODO: zoedepth, too?
+                if model_type in [1, 2, 3, 4, 5, 6] and not boost:  # TODO: zoedepth, Marigold and depth_anything, too?
                     model = model.half()
         model.to(device)  # to correct device
 
@@ -250,7 +271,8 @@ def get_default_net_size(model_type):
             7: [384, 512],
             8: [384, 768],
             9: [384, 512],
-            10: [768, 768]
+            10: [768, 768],
+            11: [518, 518]
         }
         if model_type in sizes:
             return sizes[model_type]
@@ -307,6 +329,8 @@ def get_raw_prediction(self, input, net_width, net_height):
         elif self.depth_model_type == 10:
             raw_prediction = estimatemarigold(img, self.depth_model, net_width, net_height,
                                               self.marigold_ensembles, self.marigold_steps)
+        elif self.depth_model_type == 11:
+            raw_prediction = estimatedepthanything(img, self.depth_model, net_width, net_height)
         else:
             raw_prediction = estimateboost(img, self.depth_model, self.depth_model_type, self.pix2pix_model,
                                            self.boost_rmax)
@@ -414,6 +438,7 @@ def estimatemidas(img, model, w, h, resize_mode, normalization, no_half, precisi
 # TODO: "h" is not used
 def estimatemarigold(image, model, w, h, marigold_ensembles=5, marigold_steps=12):
     # This hideous thing should be re-implemented once there is support from the upstream.
+    # TODO: re-implement this hideous thing by using features from the upstream
     img = cv2.cvtColor((image * 255.0001).astype('uint8'), cv2.COLOR_BGR2RGB)
     img = Image.fromarray(img)
     with torch.no_grad():
@@ -423,6 +448,37 @@ def estimatemarigold(image, model, w, h, marigold_ensembles=5, marigold_steps=12
     return cv2.resize(pipe_out.depth_np, (image.shape[:2][::-1]), interpolation=cv2.INTER_CUBIC)
 
 
+def estimatedepthanything(image, model, w, h):
+    from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
+    transform = Compose(
+        [
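+            # DINOv2 ViT-L/14 only accepts sides that are multiples of its 14px patch size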
+            Resize(
+                width=w // 14 * 14,
+                height=h // 14 * 14,
+                resize_target=False,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=14,
+                resize_method="lower_bound",
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
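+            # standard ImageNet mean/std normalization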
+            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            PrepareForNet(),
+        ]
+    )
+
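+    # PrepareForNet yields a CHW float32 array; add a batch dim and move it to the model's device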
+    timage = transform({"image": image})["image"]
+    timage = torch.from_numpy(timage).unsqueeze(0).to(next(model.parameters()).device)
+
+    with torch.no_grad():
+        depth = model(timage)
+        import torch.nn.functional as F
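+        # the prediction comes out at the network's working resolution; interpolate back to the input size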
+        depth = F.interpolate(
+            depth[None], (image.shape[0], image.shape[1]), mode="bilinear", align_corners=False
+        )[0, 0]
+
+    return depth.cpu().numpy()
+
+
 class ImageandPatchs:
     def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
         self.root_dir = root_dir
@@ -640,13 +696,14 @@ def estimateboost(img, model, model_type, pix2pixmodel, whole_size_threshold):
 
     if model_type == 0:  # leres
         net_receptive_field_size = 448
-        patch_netsize = 2 * net_receptive_field_size
     elif model_type == 1:  # dpt_beit_large_512
         net_receptive_field_size = 512
-        patch_netsize = 2 * net_receptive_field_size
+    elif model_type == 11:  # depth_anything
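+        # 518 = 37 * 14, the resolution Depth Anything was trained at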
+        net_receptive_field_size = 518
     else:  # other midas  # TODO Marigold support
         net_receptive_field_size = 384
-        patch_netsize = 2 * net_receptive_field_size
+    patch_netsize = 2 * net_receptive_field_size
+    # Good luck trying to use zoedepth
 
     gc.collect()
     backbone.torch_gc()
@@ -916,6 +973,8 @@ def singleestimate(img, msize, model, net_type):
         return estimateleres(img, model, msize, msize)
     elif net_type == 10:
         return estimatemarigold(img, model, msize, msize)
+    elif net_type == 11:
+        return estimatedepthanything(img, model, msize, msize)
     elif net_type >= 7:
         # np to PIL
         return estimatezoedepth(Image.fromarray(np.uint8(img * 255)).convert('RGB'), model, msize, msize)