
Commit 1fb18e2

Commit message: doc
Parent: 53d0cc9

3 files changed: 55 additions, 45 deletions


onnx_diagnostic/export/shape_helper.py

Lines changed: 5 additions & 2 deletions
@@ -215,7 +215,10 @@ def make_fake_with_dynamic_dimensions(
     .. runpython::
         :showcode:
 
-        from onnx_diagnostic.export.dynamic_shapes import make_fake_with_dynamic_dimensions
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+        from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
 
         inputs, _ = make_fake_with_dynamic_dimensions(
             dict(
@@ -245,7 +248,7 @@ def make_fake_with_dynamic_dimensions(
                 ],
             },
         )
-        print(inputs)
+        pprint.pprint(inputs)
     """
     if x is None:
         return None, None
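
For readers who want to try the updated example, a minimal sketch follows. The imports and the call/return pattern come from the diff; the dictionary passed to ``make_fake_with_dynamic_dimensions`` is hypothetical, since the middle of the docstring example is not part of this hunk.

    # Minimal sketch; only the imports and the call shape are taken from the diff,
    # the payload below is made up for illustration.
    import pprint
    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
    from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions

    inputs, _ = make_fake_with_dynamic_dimensions(
        dict(
            input_ids=torch.randint(0, 100, (2, 7)),
            past_key_values=make_dynamic_cache(
                [(torch.randn(2, 4, 7, 8), torch.randn(2, 4, 7, 8)) for _ in range(2)]
            ),
        )
    )
    pprint.pprint(inputs)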

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 46 additions & 42 deletions
@@ -168,11 +168,15 @@ def make_dynamic_cache(
             ]
         )
         print(string_type(past_key_values, with_shape=True))
+
+    The function fully supports ``FakeTensor`` inputs with dynamic dimensions if
+    ``transformers>=4.56``. Before that version, only ``FakeTensor`` inputs with
+    static dimensions are supported.
     """
     if (
         key_value_pairs
         and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
-        and pv.Version(transformers.__version__) >= pv.Version("4.55")
+        and pv.Version(transformers.__version__) >= pv.Version("4.56")
     ):
         cache = transformers.cache_utils.DynamicCache()
         cache.layers.extend(
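
The version gate introduced above can be read as a standalone predicate. The sketch below is illustrative only, not code from the commit: it assumes ``pv`` is ``packaging.version`` (the alias used in the module) and the helper name is hypothetical.

    # Illustrative sketch: FakeTensor key/value pairs only take the DynamicCache
    # path when transformers is recent enough (threshold taken from the diff above).
    import packaging.version as pv
    import torch
    import transformers

    def _fake_dynamic_cache_supported(key_value_pairs) -> bool:
        # key_value_pairs: list of (key, value) tensor pairs, as in make_dynamic_cache
        return bool(
            key_value_pairs
            and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
            and pv.Version(transformers.__version__) >= pv.Version("4.56")
        )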
@@ -516,51 +520,51 @@ def make_hybrid_cache(
 
     .. code-block:: python
 
-        self.max_cache_len = (
-            max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+            self.max_cache_len = (
+                max_cache_len if max_cache_len is not None else config.max_position_embeddings)
 
-        # Sliding layers can't be larger than the overall max cache len
-        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
-        self.max_batch_size = max_batch_size
+            # Sliding layers can't be larger than the overall max cache len
+            self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+            self.max_batch_size = max_batch_size
 
-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim")
-            else config.hidden_size // config.num_attention_heads
-        )
+            self.head_dim = (
+                config.head_dim if hasattr(config, "head_dim")
+                else config.hidden_size // config.num_attention_heads
+            )
 
-        self._dtype = dtype
-        self.num_key_value_heads = (
-            config.num_attention_heads
-            if getattr(config, "num_key_value_heads", None) is None
-            else config.num_key_value_heads
-        )
+            self._dtype = dtype
+            self.num_key_value_heads = (
+                config.num_attention_heads
+                if getattr(config, "num_key_value_heads", None) is None
+                else config.num_key_value_heads
+            )
 
-        # If the attribute does not exist in the config, fallback to a simple StaticCache
-        if hasattr(config, "layer_types"):
-            self.is_sliding = [
-                layer_type != "full_attention" for layer_type in config.layer_types]
-        else:
-            self.is_sliding = [False] * config.num_hidden_layers
-
-        self.key_cache: list[torch.Tensor] = []
-        self.value_cache: list[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
-                              self.max_cache_len, self.head_dim)
-        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
-                               self.sliding_window_len, self.head_dim)
-        self.sliding_window = min(config.sliding_window, max_cache_len)
-        device = torch.device(device) if device is not None else None
-        for i in range(config.num_hidden_layers):
-            layer_device = layer_device_map[i] if layer_device_map is not None else device
-            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
-            new_layer_key_cache = torch.zeros(
-                cache_shape, dtype=self._dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(
-                cache_shape, dtype=self._dtype, device=layer_device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+            # If the attribute does not exist in the config, fallback to a simple StaticCache
+            if hasattr(config, "layer_types"):
+                self.is_sliding = [
+                    layer_type != "full_attention" for layer_type in config.layer_types]
+            else:
+                self.is_sliding = [False] * config.num_hidden_layers
+
+            self.key_cache: list[torch.Tensor] = []
+            self.value_cache: list[torch.Tensor] = []
+            global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                                  self.max_cache_len, self.head_dim)
+            sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                                   self.sliding_window_len, self.head_dim)
+            self.sliding_window = min(config.sliding_window, max_cache_len)
+            device = torch.device(device) if device is not None else None
+            for i in range(config.num_hidden_layers):
+                layer_device = layer_device_map[i] if layer_device_map is not None else device
+                cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+                new_layer_key_cache = torch.zeros(
+                    cache_shape, dtype=self._dtype, device=layer_device)
+                new_layer_value_cache = torch.zeros(
+                    cache_shape, dtype=self._dtype, device=layer_device)
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+                self.key_cache.append(new_layer_key_cache)
+                self.value_cache.append(new_layer_value_cache)
     """
     layer_types = None
     if key_value_pairs:
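
The block quoted in that docstring boils down to one shape decision per layer. The following sketch is illustrative only (the function name and the sample numbers are made up); it mirrors the global vs. sliding choice shown above.

    # Illustrative sketch, not library code: sliding layers are capped at the
    # window length, full-attention layers get the full max_cache_len.
    def layer_cache_shape(is_sliding_layer, max_batch_size, num_key_value_heads,
                          max_cache_len, sliding_window, head_dim):
        seq_len = min(sliding_window, max_cache_len) if is_sliding_layer else max_cache_len
        return (max_batch_size, num_key_value_heads, seq_len, head_dim)

    print(layer_cache_shape(True, 2, 4, 1024, 256, 64))   # (2, 4, 256, 64)
    print(layer_cache_shape(False, 2, 4, 1024, 256, 64))  # (2, 4, 1024, 64)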

onnx_diagnostic/helpers/fake_tensor_helper.py

Lines changed: 4 additions & 1 deletion
@@ -86,6 +86,9 @@ def make_fake(
     .. runpython::
         :showcode:
 
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
         from onnx_diagnostic.helpers.fake_tensor_helper import make_fake
 
         inputs, _ = make_fake(
@@ -107,7 +110,7 @@ def make_fake(
                 ),
             )
         )
-        print(inputs)
+        pprint.pprint(inputs)
     """
     if x is None:
         return None, None
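
The switch from ``print`` to ``pprint.pprint`` in both docstring examples is purely cosmetic: the fake inputs are nested dictionaries, which read much better pretty-printed. A tiny self-contained illustration (the sample data is made up):

    # Toy illustration of the print -> pprint.pprint change; the data is made up.
    import pprint

    nested = {
        "input_ids": [[101, 2023, 2003, 102]],
        "past_key_values": [("key_layer_0", "value_layer_0"), ("key_layer_1", "value_layer_1")],
    }
    print(nested)                    # everything on one long line
    pprint.pprint(nested, width=60)  # wrapped and indented for readability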
