1- import os
1+ import os . path
22from typing import List , Optional , Tuple , Union
33
44import numpy as np
55import torch
66import torchvision
7+ from groundingdino .util .inference import load_model , predict
78from inference_exp import Detections
89from inference_exp .configuration import DEFAULT_DEVICE
910from inference_exp .entities import ColorFormat , ImageDimensions
10- from inference_exp .errors import MissingDependencyError , ModelRuntimeError
11+ from inference_exp .errors import ModelRuntimeError
1112from inference_exp .models .base .object_detection import (
1213 OpenVocabularyObjectDetectionModel ,
1314)
1415from inference_exp .models .common .model_packages import get_model_package_contents
15- from inference_exp .utils .download import download_files_to_directory
1616from torch import nn
1717from torchvision import transforms
1818from torchvision .ops import box_convert
1919
20- try :
21- from groundingdino .util .inference import load_model , predict
22- except ImportError as import_error :
23- raise MissingDependencyError (
24- message = f"Could not import GroundingDino model - this error means that some additional dependencies "
25- f"are not installed in the environment. If you run the `inference-exp` library directly in your Python "
26- f"program, make sure the following extras of the package are installed: `grounding-dino`."
27- f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
28- f"You can also contact Roboflow to get support." ,
29- help_url = "https://todo" ,
30- ) from import_error
31-
32-
33- DEFAULT_CONFIG_URL = "https://raw.githubusercontent.com/roboflow/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
34- DEFAULT_CONFIG_MD5 = "bdb07fc17b611d622633d133d2cf873a"
35-
3620
3721class GroundingDinoForObjectDetectionTorch (
3822 OpenVocabularyObjectDetectionModel [
@@ -50,23 +34,16 @@ def from_pretrained(
5034 ) -> "GroundingDinoForObjectDetectionTorch" :
5135 model_package_content = get_model_package_contents (
5236 model_package_dir = model_name_or_path ,
53- elements = ["groundingdino_swint_ogc .pth" ],
37+ elements = ["weights .pth" , "config.py " ],
5438 )
55- config_path = os .path .join (model_name_or_path , "GroundingDINO_SwinT_OGC.py" )
56- if not os .path .exists (config_path ):
57- download_files_to_directory (
58- target_dir = model_name_or_path ,
59- files_specs = [
60- (
61- "GroundingDINO_SwinT_OGC.py" ,
62- DEFAULT_CONFIG_URL ,
63- DEFAULT_CONFIG_MD5 ,
64- )
65- ],
66- )
39+ text_encoder_dir = os .path .join (model_name_or_path , "text_encoder" )
40+ loader_kwargs = {}
41+ if os .path .isdir (text_encoder_dir ):
42+ loader_kwargs ["text_encoder_type" ] = text_encoder_dir
6743 model = load_model (
68- model_config_path = config_path ,
69- model_checkpoint_path = model_package_content ["groundingdino_swint_ogc.pth" ],
44+ model_config_path = model_package_content ["config.py" ],
45+ model_checkpoint_path = model_package_content ["weights.pth" ],
46+ ** loader_kwargs ,
7047 ).to (device )
7148 return cls (model = model , device = device )
7249
@@ -176,19 +153,20 @@ def forward(
176153 text_threshold = conf_thresh
177154 caption = ". " .join (classes )
178155 all_boxes , all_logits , all_phrases = [], [], []
179- for image in pre_processed_images :
180- boxes , logits , phrases = predict (
181- model = self ._model ,
182- image = image ,
183- caption = caption ,
184- box_threshold = conf_thresh ,
185- text_threshold = text_threshold ,
186- device = self ._device ,
187- remove_combined = True ,
188- )
189- all_boxes .append (boxes )
190- all_logits .append (logits )
191- all_phrases .append (phrases )
156+ with torch .inference_mode ():
157+ for image in pre_processed_images :
158+ boxes , logits , phrases = predict (
159+ model = self ._model ,
160+ image = image ,
161+ caption = caption ,
162+ box_threshold = conf_thresh ,
163+ text_threshold = text_threshold ,
164+ device = self ._device ,
165+ remove_combined = True ,
166+ )
167+ all_boxes .append (boxes )
168+ all_logits .append (logits )
169+ all_phrases .append (phrases )
192170 return all_boxes , all_logits , all_phrases , classes
193171
194172 def post_process (
0 commit comments