|
36 | 36 | rename_dynamic_dimensions, |
37 | 37 | rename_dynamic_expression, |
38 | 38 | ) |
39 | | -from onnx_diagnostic.cache_helpers import make_dynamic_cache |
| 39 | +from onnx_diagnostic.cache_helpers import make_dynamic_cache, make_encoder_decoder_cache |
40 | 40 |
|
41 | 41 | TFLOAT = onnx.TensorProto.FLOAT |
42 | 42 |
|
@@ -164,6 +164,8 @@ def test_flatten(self): |
164 | 164 | }, |
165 | 165 | ], |
166 | 166 | ) |
| 167 | + diff = max_diff(inputs, inputs, flatten=True, verbose=10) |
| 168 | + self.assertEqual(diff["abs"], 0) |
167 | 169 | flat = flatten_object(inputs, drop_keys=True) |
168 | 170 | diff = max_diff(inputs, flat, flatten=True, verbose=10) |
169 | 171 | self.assertEqual(diff["abs"], 0) |
@@ -442,6 +444,32 @@ def test_from_tensor(self): |
442 | 444 | convert_endian(proto) |
443 | 445 | dtype_to_tensor_dtype(dt) |
444 | 446 |
|
| 447 | + @hide_stdout() |
| 448 | + def test_flatten_encoder_decoder_cache(self): |
| 449 | + inputs = ( |
| 450 | + torch.rand((3, 4), dtype=torch.float16), |
| 451 | + [ |
| 452 | + torch.rand((5, 6), dtype=torch.float16), |
| 453 | + torch.rand((5, 6, 7), dtype=torch.float16), |
| 454 | + { |
| 455 | + "a": torch.rand((2,), dtype=torch.float16), |
| 456 | + "cache": make_encoder_decoder_cache( |
| 457 | + make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), |
| 458 | + make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), |
| 459 | + ), |
| 460 | + }, |
| 461 | + ], |
| 462 | + ) |
| 463 | + diff = max_diff(inputs, inputs, flatten=True, verbose=10) |
| 464 | + self.assertEqual(diff["abs"], 0) |
| 465 | + flat = flatten_object(inputs, drop_keys=True) |
| 466 | + diff = max_diff(inputs, flat, flatten=True, verbose=10) |
| 467 | + self.assertEqual(diff["abs"], 0) |
| 468 | + d = string_diff(diff) |
| 469 | + self.assertIsInstance(d, str) |
| 470 | + s = string_type(inputs) |
| 471 | + self.assertIn("EncoderDecoderCache", s) |
| 472 | + |
445 | 473 |
|
446 | 474 | if __name__ == "__main__": |
447 | 475 | unittest.main(verbosity=2) |