Skip to content

Commit d1b8ba7

Browse files
committed
fix export
1 parent b9def44 commit d1b8ba7

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -145,23 +145,23 @@ def test_phi2_export_interpreter(self):
145145
strict=False, # True works but then the it fails during the execution
146146
)
147147

148-
# from experimental_experiment.torch_interpreter.tracing import CustomTracer
149-
# CustomTracer.remove_unnecessary_slices(ep.graph)
150-
memorize = []
148+
# from experimental_experiment.torch_interpreter.tracing import CustomTracer
149+
# CustomTracer.remove_unnecessary_slices(ep.graph)
150+
memorize = []
151151

152-
class MyInterpreter(torch.fx.Interpreter):
153-
def call_function(self, target, args, kwargs):
154-
res = super().call_function(target, args, kwargs)
155-
memorize.append((target, args, kwargs, res))
156-
return res
152+
class MyInterpreter(torch.fx.Interpreter):
153+
def call_function(self, target, args, kwargs):
154+
res = super().call_function(target, args, kwargs)
155+
memorize.append((target, args, kwargs, res))
156+
return res
157157

158-
inputs_copied = copy.deepcopy(inputs)
159-
self.assertEqual(
160-
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
161-
)
162-
args, _spec = torch.utils._pytree.tree_flatten(inputs_copied)
163-
got = MyInterpreter(ep.module()).run(*args)
164-
self.assertEqualAny(expected, got)
158+
inputs_copied = copy.deepcopy(inputs)
159+
self.assertEqual(
160+
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
161+
)
162+
args, _spec = torch.utils._pytree.tree_flatten(inputs_copied)
163+
got = MyInterpreter(ep.module()).run(*args)
164+
self.assertEqualAny(expected, got)
165165

166166

167167
if __name__ == "__main__":

0 commit comments

Comments
 (0)