@@ -145,23 +145,23 @@ def test_phi2_export_interpreter(self):
145145 strict = False , # True works but then the it fails during the execution
146146 )
147147
148- # from experimental_experiment.torch_interpreter.tracing import CustomTracer
149- # CustomTracer.remove_unnecessary_slices(ep.graph)
150- memorize = []
148+ # from experimental_experiment.torch_interpreter.tracing import CustomTracer
149+ # CustomTracer.remove_unnecessary_slices(ep.graph)
150+ memorize = []
151151
152- class MyInterpreter (torch .fx .Interpreter ):
153- def call_function (self , target , args , kwargs ):
154- res = super ().call_function (target , args , kwargs )
155- memorize .append ((target , args , kwargs , res ))
156- return res
152+ class MyInterpreter (torch .fx .Interpreter ):
153+ def call_function (self , target , args , kwargs ):
154+ res = super ().call_function (target , args , kwargs )
155+ memorize .append ((target , args , kwargs , res ))
156+ return res
157157
158- inputs_copied = copy .deepcopy (inputs )
159- self .assertEqual (
160- str_inputs , string_type (inputs_copied , with_shape = True , with_min_max = True )
161- )
162- args , _spec = torch .utils ._pytree .tree_flatten (inputs_copied )
163- got = MyInterpreter (ep .module ()).run (* args )
164- self .assertEqualAny (expected , got )
158+ inputs_copied = copy .deepcopy (inputs )
159+ self .assertEqual (
160+ str_inputs , string_type (inputs_copied , with_shape = True , with_min_max = True )
161+ )
162+ args , _spec = torch .utils ._pytree .tree_flatten (inputs_copied )
163+ got = MyInterpreter (ep .module ()).run (* args )
164+ self .assertEqualAny (expected , got )
165165
166166
167167if __name__ == "__main__" :
0 commit comments