@@ -1,8 +1,7 @@
 from typing import Any, Dict, List, Set, Optional, Tuple, Union
-from ..helpers import flatten_object
 from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
-from ..helpers.fake_tensor_helper import make_fake
-from .dynamic_shapes import ModelInputs, _flatten_dynamic_shapes
+from ..helpers.fake_tensor_helper import fake_reshape
+from .dynamic_shapes import ModelInputs
 
 
 def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
@@ -204,10 +203,10 @@ def guess_dynamic_shapes_from_inputs( |
 
 
 def make_fake_with_dynamic_dimensions(
-    inputs: Any,
+    x: Any,
     dynamic_shapes: Any,
     fake_mode: Optional["FakeTensorMode"] = None,  # noqa: F821
-) -> Any:  # noqa: F821
+) -> Tuple[Any, "FakeTensorMode"]:  # noqa: F821
     """
     Replaces all tensors with fake tensors respecting the same
     constraints as the given dynamic shapes.
@@ -235,19 +234,81 @@ def make_fake_with_dynamic_dimensions( |
                         ),
                     ]
                 ),
-        )
+        ),
+        dynamic_shapes={
+            "input_ids": {0: "batch", 1: "seq_length"},
+            "attention_mask": {0: "batch", 1: "cache+seq"},
+            "position_ids": {0: "batch", 1: "seq_length"},
+            "past_key_values": [
+                [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
+                [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
+            ],
+        },
     )
     print(inputs)
     """
-    flat_inputs = flatten_object(inputs, drop_keys=True)
-    flat_fake, fake_mode = make_fake(flat_inputs, fake_mode=fake_mode)
-    flat_ds = _flatten_dynamic_shapes(dynamic_shapes)
-    assert len(flat_inputs) == len(flat_ds), (
-        f"Mismatch between the number of input tensor {len(flat_inputs)} "
-        f"and the number of dynamic_shapes {len(flat_ds)}"
+    if x is None:
+        return None, None
+    if fake_mode is None:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
+        from torch._subclasses.fake_tensor import FakeTensorMode
+
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+
+    if isinstance(x, (list, tuple)):
+        return (
+            x.__class__(
+                [
+                    make_fake_with_dynamic_dimensions(
+                        i, fake_mode=fake_mode, dynamic_shapes=ds
+                    )[0]
+                    for i, ds in zip(x, dynamic_shapes)
+                ]
+            ),
+            fake_mode,
+        )
+    if isinstance(x, dict):
+        return {
+            k: make_fake_with_dynamic_dimensions(
+                v, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[k]
+            )[0]
+            for k, v in x.items()
+        }, fake_mode
+
+    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
+        assert hasattr(x, "layers"), (
+            f"Use a more recent version of transformers (>=4.55), "
+            f"'layers' not found in class {type(x)}"
+        )
+        assert (
+            isinstance(dynamic_shapes, list) and len(dynamic_shapes) == 2
+        ), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
+        for il, layer in enumerate(x.layers):
+            assert hasattr(layer, "keys") and hasattr(layer, "values"), (
+                f"Use a more recent version of transformers (>=4.55), "
+                f"'keys' or 'values' not found in class {type(layer)} ({dir(layer)})"
+            )
+            layer.keys = make_fake_with_dynamic_dimensions(
+                layer.keys, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0][il]
+            )[0]
+            layer.values = make_fake_with_dynamic_dimensions(
+                layer.values, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1][il]
+            )[0]
+        return x, fake_mode
+    if x.__class__.__name__ == "EncoderDecoderCache":
+        make_fake_with_dynamic_dimensions(
+            x.self_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0]
+        )
+        make_fake_with_dynamic_dimensions(
+            x.cross_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1]
+        )
+        return x, fake_mode
+    if hasattr(x, "shape"):
+        t = fake_reshape(x, dynamic_shapes, fake_mode=fake_mode)
+        return t, fake_mode
+    from . import string_type
+
+    raise TypeError(
+        f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
     )
-    flat_reshaped = [
-        make_fake_with_dynamic_dimensions(t, sh, true_tensor=t, fake_mode=fake_mode)
-        for t, sh in zip(flat_inputs, flat_fake, flat_ds)
-    ]
-    return flat_reshaped
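
A minimal usage sketch of the rewritten function, which now returns both the faked structure and the FakeTensorMode it used. The import path, shapes, and variable names below are assumptions for illustration, not taken from the diff:

    import torch

    # Hypothetical import path; use the module this diff modifies.
    from dynamic_shapes_module import make_fake_with_dynamic_dimensions

    inputs = {
        "input_ids": torch.randint(0, 1024, (2, 8)),
        "attention_mask": torch.ones((2, 8), dtype=torch.int64),
    }
    dynamic_shapes = {
        "input_ids": {0: "batch", 1: "seq_length"},
        "attention_mask": {0: "batch", 1: "seq_length"},
    }

    # Returns the same nested structure with fake tensors plus the FakeTensorMode used.
    fake_inputs, fake_mode = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)
    print(fake_inputs)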