Commit 16fddf2

Adds task object-detection (#79)

* add object-detection
* update shapes

1 parent 1642132 commit 16fddf2

File tree

8 files changed, +415 -0 lines changed

CHANGELOGS.rst

Lines changed: 6 additions & 0 deletions

@@ -1,6 +1,12 @@
 Change Logs
 ===========
 
+0.4.4
++++++
+
+* :pr:`79`: implements task ``object-detection``
+* :pr:`78`: uses *onnx-weekly* instead of *onnx* to avoid conflicts with *onnxscript*
+
 0.4.3
 +++++
 
_doc/api/tasks/index.rst

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ Or:
     image_classification
     image_text_to_text
     mixture_of_expert
+    object_detection
     sentence_similarity
     text_classification
     text_generation
_doc/api/tasks/object_detection.rst

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+
+onnx_diagnostic.tasks.object_detection
+======================================
+
+.. automodule:: onnx_diagnostic.tasks.object_detection
+    :members:
+    :no-undoc-members:
Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+class TestTasks(ExtTestCase):
+    @hide_stdout()
+    def test_object_detection(self):
+        mid = "hustvl/yolos-tiny"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "object-detection")
+        self.assertIn((data["size"], data["n_weights"]), [(8160384, 2040096)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
+        if not has_transformers("4.51.999"):
+            raise unittest.SkipTest("Requires transformers>=4.52")
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
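Outside the test harness, the same flow can be reproduced directly. A minimal sketch, assuming only the calls the test above already makes:

# Minimal sketch mirroring the test above: build an untrained clone of
# hustvl/yolos-tiny together with matching inputs, then export it with
# dynamic shapes.
import torch
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

data = get_untrained_model_with_inputs("hustvl/yolos-tiny")
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]

with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
    )
print(ep)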

_unittests/ut_tasks/try_tasks.py

Lines changed: 38 additions & 0 deletions

@@ -502,6 +502,44 @@ def test_falcon_mamba_7b(self):
         for seq in sequences:
             print(f"Result: {seq['generated_text']}")
 
+    @never_test()
+    def test_object_detection(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k object_
+        # https://huggingface.co/hustvl/yolos-tiny
+
+        from transformers import YolosImageProcessor, YolosForObjectDetection
+        from PIL import Image
+        import torch
+        import requests
+
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        model = YolosForObjectDetection.from_pretrained("hustvl/yolos-tiny")
+        image_processor = YolosImageProcessor.from_pretrained("hustvl/yolos-tiny")
+
+        inputs = image_processor(images=image, return_tensors="pt")
+        print()
+        print("-- inputs", string_type(inputs, with_shape=True, with_min_max=True))
+        outputs = model(**inputs)
+        print("-- outputs", string_type(outputs, with_shape=True, with_min_max=True))
+
+        # model predicts bounding boxes and corresponding COCO classes
+        # logits = outputs.logits
+        # bboxes = outputs.pred_boxes
+
+        # print results
+        target_sizes = torch.tensor([image.size[::-1]])
+        results = image_processor.post_process_object_detection(
+            outputs, threshold=0.9, target_sizes=target_sizes
+        )[0]
+        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+            box = [round(i, 2) for i in box.tolist()]
+            print(
+                f"Detected {model.config.id2label[label.item()]} with confidence "
+                f"{round(score.item(), 3)} at location {box}"
+            )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
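post_process_object_detection returns one dictionary per image whose scores, labels, and boxes tensors share the first axis. A hedged sketch of filtering such a result manually; the helper name is illustrative, not part of the commit:

# Illustrative helper, not part of the commit: keep only detections whose
# score exceeds a threshold. scores, labels, and boxes share the first
# axis, so one boolean mask filters all three tensors consistently.
def keep_confident(results: dict, min_score: float = 0.9) -> dict:
    mask = results["scores"] > min_score
    return {key: value[mask] for key, value in results.items()}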

onnx_diagnostic/tasks/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,7 @@
     image_classification,
     image_text_to_text,
     mixture_of_expert,
+    object_detection,
     sentence_similarity,
     text_classification,
     text_generation,
@@ -20,6 +21,7 @@
     image_classification,
     image_text_to_text,
     mixture_of_expert,
+    object_detection,
     sentence_similarity,
     text_classification,
     text_generation,
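A quick sanity check of the registration, assuming nothing beyond what the diff shows (the module exposes the ``__TASK__`` constant defined below):

# Sanity check based only on this diff: the new module is importable from
# onnx_diagnostic.tasks and names its task "object-detection".
from onnx_diagnostic.tasks import object_detection

assert object_detection.__TASK__ == "object-detection"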
onnx_diagnostic/tasks/object_detection.py

Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr
+
+__TASK__ = "object-detection"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces the model size."""
+    check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
+    kwargs = dict(
+        num_hidden_layers=(
+            min(config.num_hidden_layers, 2)
+            if hasattr(config, "num_hidden_layers")
+            else len(config.hidden_sizes)
+        )
+    )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    input_width: int,
+    input_height: int,
+    input_channels: int,
+    batch_size: int = 2,
+    dynamic_rope: bool = False,
+    add_second_input: bool = False,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``object-detection``.
+
+    :param model: model used to get the missing information
+    :param config: configuration used to generate the model
+    :param batch_size: batch size
+    :param input_channels: number of input channels
+    :param input_width: input width
+    :param input_height: input height
+    :return: dictionary
+    """
+    assert isinstance(
+        input_width, int
+    ), f"Unexpected type for input_width {type(input_width)}{config}"
+    assert isinstance(
+        input_height, int
+    ), f"Unexpected type for input_height {type(input_height)}{config}"
+
+    shapes = {
+        "pixel_values": {
+            0: torch.export.Dim("batch", min=1, max=1024),
+            2: "width",
+            3: "height",
+        }
+    }
+    inputs = dict(
+        pixel_values=torch.randn(batch_size, input_channels, input_width, input_height).clamp(
+            -1, 1
+        ),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            input_width=input_width + 1,
+            input_height=input_height + 1,
+            input_channels=input_channels,
+            batch_size=batch_size + 1,
+            dynamic_rope=dynamic_rope,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        if (
+            hasattr(config, "model_type")
+            and config.model_type == "timm_wrapper"
+            and not hasattr(config, "num_hidden_layers")
+        ):
+            input_size = config.pretrained_cfg["input_size"]
+            kwargs = dict(
+                batch_size=2,
+                input_width=input_size[-2],
+                input_height=input_size[-1],
+                input_channels=input_size[-3],
+            )
+            return kwargs, get_inputs
+
+        check_hasattr(config, ("image_size", "architectures"), "num_channels")
+    if config is not None:
+        if hasattr(config, "image_size"):
+            image_size = config.image_size
+        else:
+            assert config.architectures, f"empty architecture in {config}"
+            from ..torch_models.hghub.hub_api import get_architecture_default_values
+
+            default_values = get_architecture_default_values(config.architectures[0])
+            image_size = default_values["image_size"]
+    if config is None or isinstance(image_size, int):
+        kwargs = dict(
+            batch_size=2,
+            input_width=224 if config is None else image_size,
+            input_height=224 if config is None else image_size,
+            input_channels=3 if config is None else config.num_channels,
+        )
+    else:
+        kwargs = dict(
+            batch_size=2,
+            input_width=config.image_size[0],
+            input_height=config.image_size[1],
+            input_channels=config.num_channels,
+        )
+    return kwargs, get_inputs
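A hedged usage sketch of the two entry points added above: with config=None, random_input_kwargs falls back to the default 224x224, 3-channel dimensions and returns get_inputs to materialize them:

# Sketch under the config=None fallback: default dimensions are selected,
# then get_inputs builds the pixel_values tensor and the dynamic shapes.
from onnx_diagnostic.tasks.object_detection import random_input_kwargs

kwargs, fct = random_input_kwargs(None)       # batch_size=2, 224x224, 3 channels
res = fct(model=None, config=None, **kwargs)  # model is unused by this task
print(res["inputs"]["pixel_values"].shape)    # torch.Size([2, 3, 224, 224])
print(res["dynamic_shapes"])                  # batch, width, height stay dynamic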
