
Commit 8f1c984

Add phi
1 parent e9720f7 commit 8f1c984

5 files changed: +99, -98 lines

_doc/api/torch_models/llms.rst

Lines changed: 1 addition & 2 deletions
@@ -3,5 +3,4 @@ onnx_diagnostic.torch_models.llms
 =================================

 .. automodule:: onnx_diagnostic.torch_models.llms
-    :members:
-    :no-undoc-members:
+    :members: get_tiny_llm
onnx_diagnostic/torch_models/llms.py

Lines changed: 1 addition & 96 deletions
@@ -1,96 +1 @@
-from typing import Any, Dict
-import torch
-import transformers
-from ..cache_helpers import make_dynamic_cache
-
-
-def get_tiny_llm(
-    batch_size: int = 2,
-    input_cache: bool = True,
-    dynamic_rope: bool = False,
-    **kwargs,
-) -> Dict[str, Any]:
-    """
-    Gets a non initialized model.
-
-    :param batch_size: batch size
-    :param input_cache: generate data for this iteration with or without cache
-    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
-    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
-    :return: dictionary
-
-    See :ref:`l-plot-tiny-llm-export` for an example.
-    """
-    config = {
-        "architectures": ["LlamaForCausalLM"],
-        "bos_token_id": 1,
-        "eos_token_id": 2,
-        "hidden_act": "silu",
-        "hidden_size": 192,
-        "initializer_range": 0.02,
-        "intermediate_size": 1024,
-        "max_position_embeddings": 1024,
-        "model_type": "llama",
-        "num_attention_heads": 2,
-        "num_hidden_layers": 1,
-        "num_key_value_heads": 1,
-        "pretraining_tp": 1,
-        "rms_norm_eps": 1e-05,
-        "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
-        "tie_word_embeddings": False,
-        "torch_dtype": "float32",
-        "transformers_version": "4.31.0.dev0",
-        "use_cache": True,
-        "vocab_size": 32000,
-    }
-
-    config.update(**kwargs)
-    conf = transformers.LlamaConfig(**config)
-    model = transformers.LlamaForCausalLM(conf)
-    model.eval()
-
-    # now the inputs
-    cache_last_dim = 96
-    sequence_length = 30
-    sequence_length2 = 3
-    num_key_value_heads = 1
-    max_token_id = config["vocab_size"] - 1
-    n_layers = config["num_hidden_layers"]
-
-    batch = torch.export.Dim("batch", min=1, max=1024)
-    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = torch.export.Dim("cache_length", min=1, max=4096)
-
-    shapes = {
-        "input_ids": {0: batch, 1: seq_length},
-        "attention_mask": {
-            0: batch,
-            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
-        },
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(n_layers)],
-            [{0: batch, 2: cache_length} for _ in range(n_layers)],
-        ],
-    }
-    inputs = dict(
-        input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
-        ),
-        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-            torch.int64
-        ),
-        past_key_values=make_dynamic_cache(
-            [
-                (
-                    torch.randn(
-                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
-                    ),
-                    torch.randn(
-                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
-                    ),
-                )
-                for i in range(n_layers)
-            ]
-        ),
-    )
-    return dict(inputs=inputs, model=model, dynamic_shapes=shapes)
+from .untrained.tiny_ll import get_tiny_llm
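The old import path keeps working because llms.py is now a thin re-export of the relocated function. A minimal sanity check, illustrative only and assuming the package layout shown in this diff:

# Illustrative only: after this commit both paths should resolve to the same
# function, since llms.py merely re-exports get_tiny_llm from the new module.
from onnx_diagnostic.torch_models.llms import get_tiny_llm as via_old_path
from onnx_diagnostic.torch_models.untrained.tiny_ll import get_tiny_llm as via_new_path

assert via_old_path is via_new_path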

onnx_diagnostic/torch_models/untrained/__init__.py

Whitespace-only changes.
onnx_diagnostic/torch_models/untrained/tiny_ll.py

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+from typing import Any, Dict
+import torch
+import transformers
+from ..cache_helpers import make_dynamic_cache
+
+
+def get_tiny_llm(
+    batch_size: int = 2,
+    input_cache: bool = True,
+    dynamic_rope: bool = False,
+    **kwargs,
+) -> Dict[str, Any]:
+    """
+    Gets a non initialized model.
+
+    :param batch_size: batch size
+    :param input_cache: generate data for this iteration with or without cache
+    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
+    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
+    :return: dictionary
+
+    See :ref:`l-plot-tiny-llm-export` for an example.
+    """
+    config = {
+        "architectures": ["LlamaForCausalLM"],
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "hidden_act": "silu",
+        "hidden_size": 192,
+        "initializer_range": 0.02,
+        "intermediate_size": 1024,
+        "max_position_embeddings": 1024,
+        "model_type": "llama",
+        "num_attention_heads": 2,
+        "num_hidden_layers": 1,
+        "num_key_value_heads": 1,
+        "pretraining_tp": 1,
+        "rms_norm_eps": 1e-05,
+        "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
+        "tie_word_embeddings": False,
+        "torch_dtype": "float32",
+        "transformers_version": "4.31.0.dev0",
+        "use_cache": True,
+        "vocab_size": 32000,
+    }
+
+    config.update(**kwargs)
+    conf = transformers.LlamaConfig(**config)
+    model = transformers.LlamaForCausalLM(conf)
+    model.eval()
+
+    # now the inputs
+    cache_last_dim = 96
+    sequence_length = 30
+    sequence_length2 = 3
+    num_key_value_heads = 1
+    max_token_id = config["vocab_size"] - 1
+    n_layers = config["num_hidden_layers"]
+
+    batch = torch.export.Dim("batch", min=1, max=1024)
+    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
+    cache_length = torch.export.Dim("cache_length", min=1, max=4096)
+
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "attention_mask": {
+            0: batch,
+            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
+        },
+        "past_key_values": [
+            [{0: batch, 2: cache_length} for _ in range(n_layers)],
+            [{0: batch, 2: cache_length} for _ in range(n_layers)],
+        ],
+    }
+    inputs = dict(
+        input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
+            torch.int64
+        ),
+        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+            torch.int64
+        ),
+        past_key_values=make_dynamic_cache(
+            [
+                (
+                    torch.randn(
+                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
+                    ),
+                    torch.randn(
+                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
+                    ),
+                )
+                for i in range(n_layers)
+            ]
+        ),
+    )
+    return dict(inputs=inputs, model=model, dynamic_shapes=shapes)
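For reference, a minimal sketch of how the dictionary returned by get_tiny_llm might be consumed. The export call is only an illustration, not part of this commit; depending on the transformers version, exporting the DynamicCache input may additionally require the package's torch_export_patches helpers.

# Sketch only: build the tiny LLaMA model with its random inputs and dynamic
# shapes, run it eagerly, then try a torch.export with those shapes.
import torch
from onnx_diagnostic.torch_models.llms import get_tiny_llm

data = get_tiny_llm(num_hidden_layers=1)
model = data["model"]
inputs = data["inputs"]
dynamic_shapes = data["dynamic_shapes"]

with torch.no_grad():
    outputs = model(**inputs)    # eager forward pass on the generated inputs
print(outputs.logits.shape)      # (batch_size, sequence_length2, vocab_size)

# Export with the dynamic dimensions declared by get_tiny_llm; on some versions
# this may need the package's export patches for the cache class.
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=dynamic_shapes)
print(ep)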

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -98,3 +98,4 @@ select = [
 "onnx_diagnostic/reference/__init__.py" = ["F401"]
 "onnx_diagnostic/torch_export_patches/__init__.py" = ["F401"]
 "onnx_diagnostic/torch_export_patches/patches/__init__.py" = ["F401"]
+"onnx_diagnostic/torch_models/llms.py" = ["F401"]
