
Commit e05a02a (merge, 2 parents: eab0a77 + 6ca5f73)

fix two things

4 files changed: +5 additions, -4 deletions


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ Change Logs
 +++++

 * :pr:`287`: adds input ``'inputs_prompt'`` to test a LLM, meant to be used during validation
+* :pr:`288`: add .contiguous in torch.cond branch (attention patch for sdpa implementation)
 * :pr:`286`: adds variable to track random nodes in models

 0.8.0

_doc/technical/plot_generate.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def simple_generate_with_cache(
 # seen earlier for a torch model.
 # Let's ask first the function to return the session to avoid creating on the second call.

-_res, session = onnx_generate(
+_res, session, _feeds = onnx_generate(
     model_name, inputs.input_ids, 2, max_new_tokens=2, return_session=True
 )

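The change above is breaking for callers: with ``return_session=True``, ``onnx_generate`` now returns three values instead of two. A minimal sketch of adapting a caller, assuming the third value is the feed dictionary of the last inference step (the diff shows the new arity, not the content of ``_feeds``) and assuming the import path and model id used below:

import torch
from onnx_diagnostic.helpers.rt_helper import onnx_generate  # import path assumed

model_name = "arnir0/Tiny-LLM"              # placeholder model id, an assumption
input_ids = torch.randint(0, 1000, (1, 8))  # dummy prompt token ids

# return_session=True now yields (result, session, feeds).
res, session, _feeds = onnx_generate(
    model_name, input_ids, 2, max_new_tokens=2, return_session=True
)
# `session` can be kept to avoid rebuilding the ONNX model on a second call
# (the point of the comment in plot_generate.py); `_feeds` is assumed to hold
# the inputs of the last run, handy for debugging a generation step.
print(res.shape, type(session).__name__, sorted(_feeds))

The same two-to-three unpacking fix appears again in the unit test below.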
_unittests/ut_helpers/test_rt_helper.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def test_onnx_generate(self):
         )

         print("-- test_onnx_generate: generate")
-        res, session = onnx_generate(
+        res, session, _feeds = onnx_generate(
             model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True
         )
         n_inputs = input_ids.shape[1]

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 2 additions & 2 deletions
@@ -1452,7 +1452,7 @@ def patched_sdpa_attention_forward(
             scale=scaling,
             is_causal=True,
             **sdpa_kwargs,
-        ),
+        ).contiguous(),
         lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
             query,
             key,
@@ -1461,7 +1461,7 @@ def patched_sdpa_attention_forward(
             scale=scaling,
             is_causal=False,
             **sdpa_kwargs,
-        ),
+        ).contiguous(),
         [query, key, value],
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
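Why the added ``.contiguous()``: ``torch.cond`` traces both branches and expects them to return tensors with matching metadata, and the two SDPA paths (``is_causal=True`` vs ``is_causal=False``) can come back with different memory layouts depending on which kernel serves the call, which breaks the export. Forcing ``.contiguous()`` inside each branch normalizes the layout of both outputs. A standalone sketch of the patched pattern, not the library code itself (shapes and scaling are arbitrary):

import torch
import torch.nn.functional as F

def sdpa_under_cond(query, key, value, use_causal, scaling=0.25):
    # Both branches end in .contiguous() so torch.cond sees identically
    # laid-out outputs whichever kernel serves each path.
    return torch.cond(
        use_causal,
        lambda q, k, v: F.scaled_dot_product_attention(
            q, k, v, scale=scaling, is_causal=True
        ).contiguous(),
        lambda q, k, v: F.scaled_dot_product_attention(
            q, k, v, scale=scaling, is_causal=False
        ).contiguous(),
        [query, key, value],
    )

q = k = v = torch.randn(2, 4, 8, 16)
out = sdpa_under_cond(q, k, v, torch.tensor(True))
print(out.shape)  # torch.Size([2, 4, 8, 16])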
