     bypass_export_some_errors,
     register_additional_serialization_functions,
 )
-
-"""
--- Missing dependencies --
-
-def is_torchdynamo_exporting() -> bool:
-    "Tells if torch is exporting a model."
-    import torch
-
-    if not hasattr(torch.compiler, "is_exporting"):
-        # torch.compiler.is_exporting requires torch>=2.7
-        return False
-
-    try:
-        return torch.compiler.is_exporting()
-    except Exception:
-        try:
-            import torch._dynamo as dynamo
-
-            return dynamo.is_exporting()  # type: ignore
-        except Exception:
-            return False
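``is_torchdynamo_exporting`` is meant to guard code that only makes sense in eager mode. A minimal sketch of that pattern (``TinyModel`` and the finite-value check are hypothetical, not part of this diff):

::

    import torch

    class TinyModel(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if not is_torchdynamo_exporting():
                # Value-dependent check: fine in eager mode, but it would
                # introduce a data-dependent branch during export, so skip it.
                assert bool(torch.isfinite(x).all()), "non-finite input"
            return x * 2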
-
-
-def string_type(anything, **kwargs):
-    # Stand-in for ``from onnx_diagnostic.helpers import string_type``:
-    # the real helper is too long to inline, ``str`` is enough here.
-    return str(anything)
-
-
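The stub above ignores its keyword arguments, so ``with_shape=True`` in the docstring examples below changes nothing. A richer stand-in is easy to sketch (hypothetical, not the real ``onnx_diagnostic.helpers.string_type``); it reports dtypes and shapes for tensors and for ``DynamicCache``-like objects exposing ``key_cache``/``value_cache``:

::

    import torch

    def string_type_with_shapes(anything, with_shape: bool = False) -> str:
        # Hypothetical helper: summarize tensors and nested containers.
        if isinstance(anything, torch.Tensor):
            return f"T[{anything.dtype},{tuple(anything.shape)}]" if with_shape else "T"
        if hasattr(anything, "key_cache") and hasattr(anything, "value_cache"):
            # Covers DynamicCache-like objects in the transformers versions used here.
            keys = string_type_with_shapes(list(anything.key_cache), with_shape)
            values = string_type_with_shapes(list(anything.value_cache), with_shape)
            return f"{type(anything).__name__}(key_cache={keys}, value_cache={values})"
        if isinstance(anything, (list, tuple)):
            inner = ",".join(string_type_with_shapes(t, with_shape) for t in anything)
            return f"{type(anything).__name__}({inner})"
        return type(anything).__name__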
-if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
-
-    def make_dynamic_cache(
-        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> transformers.cache_utils.DynamicCache:
-        '''
-        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
-        This version is valid for ``transformers >= 4.50``.
-
-        :param key_value_pairs: list of pairs of (key, values)
-        :return: :class:`transformers.cache_utils.DynamicCache`
-
-        Example:
-
-        ::
-
-            n_layers = 2
-            bsize, nheads, slen, dim = 2, 4, 3, 7
-
-            past_key_values = make_dynamic_cache(
-                [
-                    (
-                        torch.randn(bsize, nheads, slen, dim),
-                        torch.randn(bsize, nheads, slen, dim),
-                    )
-                    for i in range(n_layers)
-                ]
-            )
-            print(string_type(past_key_values, with_shape=True))
-        '''
-        return transformers.cache_utils.DynamicCache(key_value_pairs)
-
-else:
-
-    def make_dynamic_cache(
-        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> transformers.cache_utils.DynamicCache:
-        '''
-        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
-        This version is valid for ``transformers < 4.50``.
-
-        :param key_value_pairs: list of pairs of (key, values)
-        :return: :class:`transformers.cache_utils.DynamicCache`
-
-        Example:
-
-        ::
-
-            n_layers = 2
-            bsize, nheads, slen, dim = 2, 4, 3, 7
-
-            past_key_values = make_dynamic_cache(
-                [
-                    (
-                        torch.randn(bsize, nheads, slen, dim),
-                        torch.randn(bsize, nheads, slen, dim),
-                    )
-                    for i in range(n_layers)
-                ]
-            )
-            print(string_type(past_key_values, with_shape=True))
-        '''
-        cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))
-        for i, (key, value) in enumerate(key_value_pairs):
-            cache.update(key, value, i)
-        return cache
-
-
-def make_encoder_decoder_cache(
-    self_attention_cache: transformers.cache_utils.DynamicCache,
-    cross_attention_cache: transformers.cache_utils.DynamicCache,
-) -> transformers.cache_utils.EncoderDecoderCache:
-    "Creates an EncoderDecoderCache."
-    return transformers.cache_utils.EncoderDecoderCache(
-        self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
-    )
-
111 | | -""" |