Skip to content

Commit 03dda7f

Browse files
committed
fix zero shot
1 parent 9eebdda commit 03dda7f

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.4.0
55
+++++
66

7+
* :pr:`55`: add support for text-classification
78
* :pr:`54`: add support for fill-mask, refactoring
89
* :pr:`52`: add support for zero-shot-image-classification
910
* :pr:`50`: add support for onnxruntime fusion

onnx_diagnostic/tasks/zero_shot_image_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ def get_inputs(
6262
), f"Unexpected type for input_height {type(input_height)}{config}"
6363

6464
batch = torch.export.Dim("batch", min=1, max=1024)
65-
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
65+
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
6666
shapes = {
67-
"inputs_ids": {0: batch, 1: seq_length},
67+
"input_ids": {0: batch, 1: seq_length},
6868
"attention_mask": {0: batch, 1: seq_length},
6969
"pixel_values": {
7070
0: torch.export.Dim("batch_img", min=1, max=1024),

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def get_untrained_model_with_inputs(
105105
# outputs even with the same inputs in training mode.
106106
model.eval()
107107
res = fct(model, config, **kwargs)
108+
108109
res["input_kwargs"] = kwargs
109110
res["model_kwargs"] = mkwargs
110111

@@ -118,19 +119,24 @@ def get_untrained_model_with_inputs(
118119
update = {}
119120
for k, v in res.items():
120121
if k.startswith(("inputs", "dynamic_shapes")) and isinstance(v, dict):
121-
update[k] = filter_out_unexpected_inputs(model, v)
122+
update[k] = filter_out_unexpected_inputs(model, v, verbose=verbose)
122123
res.update(update)
123124
return res
124125

125126

126-
def filter_out_unexpected_inputs(model: torch.nn.Module, kwargs: Dict[str, Any]):
127+
def filter_out_unexpected_inputs(
128+
model: torch.nn.Module, kwargs: Dict[str, Any], verbose: int = 0
129+
):
127130
"""
128131
Removes input names in kwargs if no parameter names was found in ``model.forward``.
129132
"""
130133
sig = inspect.signature(model.forward)
131134
allowed = set(sig.parameters)
132-
kwargs = {k: v for k, v in kwargs.items() if k in allowed}
133-
return kwargs
135+
new_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
136+
diff = set(kwargs) - set(new_kwargs)
137+
if diff and verbose:
138+
print(f"[filter_out_unexpected_inputs] removed {diff}")
139+
return new_kwargs
134140

135141

136142
def compute_model_size(model: torch.nn.Module) -> Tuple[int, int]:

0 commit comments

Comments
 (0)