Commit 727dcac

add position_ids
1 parent 1e9c449 commit 727dcac

File tree

2 files changed: +7 -4 lines
_doc/examples/plot_export_tiny_llm.py

Lines changed: 2 additions & 1 deletion
@@ -67,7 +67,8 @@ def _forward_(*args, _f=None, **kwargs):
 )
 
 generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-print(generated_text)
+print("-- prompt", prompt)
+print("-- answer", generated_text)
 
 # %%
 # Let's restore the forward as it was.

onnx_diagnostic/torch_models/untrained/tiny_llm.py

Lines changed: 5 additions & 3 deletions
@@ -1,20 +1,18 @@
 from typing import Any, Dict
 import torch
 import transformers
-from ..cache_helpers import make_dynamic_cache
+from ...cache_helpers import make_dynamic_cache
 
 
 def get_tiny_llm(
     batch_size: int = 2,
-    input_cache: bool = True,
     dynamic_rope: bool = False,
     **kwargs,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model.
 
     :param batch_size: batch size
-    :param input_cache: generate data for this iteration with or without cache
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
     :return: dictionary
@@ -63,6 +61,7 @@ def get_tiny_llm(
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
+        "position_ids": {0: torch.export.Dim.DYNAMIC},
         "attention_mask": {
             0: batch,
             1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
@@ -79,6 +78,9 @@
         attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
             torch.int64
        ),
+        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
+        .to(torch.int64)
+        .expand((batch_size, -1)),
         past_key_values=make_dynamic_cache(
             [
                 (
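The new position_ids input continues the token positions past the cached prefix: with sequence_length tokens already stored in past_key_values, the sequence_length2 new tokens receive positions sequence_length through sequence_length + sequence_length2 - 1, expanded across the batch. Below is a minimal sketch of how the pieces built here might feed torch.export.export; the dictionary keys "model", "inputs", and "dynamic_shapes" are assumptions made for illustration, not something this diff shows.

import torch
from onnx_diagnostic.torch_models.untrained.tiny_llm import get_tiny_llm

# Hypothetical usage sketch: the exact keys of the returned dictionary are an
# assumption; the diff only shows that the inputs and dynamic shapes are
# assembled inside get_tiny_llm.
data = get_tiny_llm(num_hidden_layers=1)
model = data["model"]
inputs = data["inputs"]
dynamic_shapes = data["dynamic_shapes"]

# position_ids mirrors the cache layout: positions for the new tokens start
# right after the cached ones, so the exporter sees a consistent dynamic axis.
ep = torch.export.export(model, args=(), kwargs=inputs, dynamic_shapes=dynamic_shapes)
print(ep.graph)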
