1+ import copy
12import unittest
23import torch
34from transformers .cache_utils import DynamicCache
@@ -22,18 +23,21 @@ def test_get_tiny_llm(self):
2223 def test_export_tiny_llm_1 (self ):
2324 data = get_tiny_llm ()
2425 model , inputs = data ["model" ], data ["inputs" ]
26+ expected = model (** copy .deepcopy (inputs ))
2527 self .assertEqual (
2628 {"attention_mask" , "past_key_values" , "input_ids" , "position_ids" }, set (inputs )
2729 )
2830 ep = torch .export .export (
29- model , (), kwargs = inputs , dynamic_shapes = data ["dynamic_shapes" ]
31+ model , (), kwargs = copy . deepcopy ( inputs ) , dynamic_shapes = data ["dynamic_shapes" ]
3032 )
31- assert ep
33+ got = ep .module ()(** inputs )
34+ self .assertEqualArrayAny (expected , got )
3235
3336 @ignore_warnings (UserWarning )
3437 def test_export_tiny_llm_2_bypassed (self ):
3538 data = get_tiny_llm ()
3639 model , inputs = data ["model" ], data ["inputs" ]
40+ expected = model (** copy .deepcopy (inputs ))
3741 self .assertEqual (
3842 {"attention_mask" , "past_key_values" , "input_ids" , "position_ids" }, set (inputs )
3943 )
@@ -45,7 +49,7 @@ def test_export_tiny_llm_2_bypassed(self):
4549 for k in patched_DynamicCache ._PATCHES_ :
4650 self .assertEqual (getattr (patched_DynamicCache , k ), getattr (DynamicCache , k ))
4751
48- inputs = modificator (inputs )
52+ inputs = modificator (copy . deepcopy ( inputs ) )
4953
5054 def debug ():
5155 print ("***" , string_type (inputs , with_shape = True ))
@@ -67,7 +71,8 @@ def debug():
6771 ep = torch .export .export (
6872 model , (), kwargs = inputs , dynamic_shapes = data ["dynamic_shapes" ], strict = False
6973 )
70- assert ep
74+ got = ep .module ()(** inputs )
75+ self .assertEqualArrayAny (expected , got )
7176
7277
7378if __name__ == "__main__" :
0 commit comments