@@ -128,7 +128,7 @@ def test_torchbind_hop_schema(self):
128128 schema = CallTorchBind .schema (foo_ir , "add" )
129129 self .assertEqual (
130130 str (schema ),
131- "call_torchbind(__torch__.torch.classes._TorchScriptTesting._Foo _0 , str method, int _1) -> int _0" ,
131+ "call_torchbind(__torch__.torch.classes._TorchScriptTesting._Foo obj , str method, int _1) -> int _0" ,
132132 )
133133
134134 def test_torchbind_config_not_generated (self ):
@@ -146,7 +146,7 @@ def test_torchbind_hop_schema_no_input(self):
146146 schema = CallTorchBind .schema (q_ir , "pop" )
147147 self .assertEqual (
148148 str (schema ),
149- "call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0 , str method) -> Tensor _0" ,
149+ "call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue obj , str method) -> Tensor _0" ,
150150 )
151151
152152 def test_torchbind_hop_schema_no_output (self ):
@@ -155,7 +155,7 @@ def test_torchbind_hop_schema_no_output(self):
155155 schema = CallTorchBind .schema (q_ir , "push" )
156156 self .assertEqual (
157157 str (schema ),
158- "call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0 , str method, Tensor _1) -> NoneType _0" ,
158+ "call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue obj , str method, Tensor _1) -> NoneType _0" ,
159159 )
160160
161161 def test_torchbind_aot_compile (self ):
@@ -250,7 +250,7 @@ def test_torchbind_aot_compile(self):
250250 "target" : "call_torchbind" ,
251251 "inputs" : [
252252 {
253- "name" : "_0 " ,
253+ "name" : "obj " ,
254254 "arg" : {
255255 "as_custom_obj" : {
256256 "name" : "_torchbind_obj0" ,
@@ -293,20 +293,15 @@ def test_torchbind_aot_compile(self):
293293 self .assertTrue ((tmp_path_model / "custom_objs_config.json" ).exists ())
294294 self .assertTrue ((tmp_path_constants / "custom_obj_0" ).exists ())
295295
296- def test_torchbind_aoti (self ):
297- ep , inputs , orig_res , _ = self .get_exported_model ()
298- pt2_path = torch ._inductor .aoti_compile_and_package (ep )
299- optimized = torch ._inductor .aoti_load_package (pt2_path )
300- result = optimized (* inputs )
301- self .assertEqual (result , orig_res )
296+ # TODO: add accuracy test after we support loading and running compiled models with
297+ # torchbind objects.
302298
303299 @torch ._inductor .config .patch ("aot_inductor.use_runtime_constant_folding" , True )
304300 def test_torchbind_aot_compile_constant_folding (self ):
305- ep , inputs , orig_res , _ = self .get_exported_model ()
306- pt2_path = torch ._inductor .aoti_compile_and_package (ep )
307- optimized = torch ._inductor .aoti_load_package (pt2_path )
308- result = optimized (* inputs )
309- self .assertEqual (result , orig_res )
301+ ep , inputs , _ , _ = self .get_exported_model ()
302+ aot_compile (ep .module (), inputs , options = {"aot_inductor.package" : True })
303+ # TODO: add accuracy test after we support loading and running compiled models with
304+ # torchbind objects.
310305
311306 def test_torchbind_list_return_aot_compile (self ):
312307 class M (torch .nn .Module ):
@@ -322,48 +317,15 @@ def forward(self, x):
322317
323318 m = M ()
324319 inputs = (torch .ones (2 , 3 ),)
325- orig_res = m (* inputs )
326320
327321 # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
328322 with enable_torchbind_tracing ():
329323 ep = torch .export .export (m , inputs , strict = False )
330324
331- pt2_path = torch ._inductor .aoti_compile_and_package (ep )
332- optimized = torch ._inductor .aoti_load_package (pt2_path )
333- result = optimized (* inputs )
334- self .assertEqual (result , orig_res )
335-
336- def test_torchbind_queue (self ):
337- class Foo (torch .nn .Module ):
338- def __init__ (self , tq ) -> None :
339- super ().__init__ ()
340- self .tq = tq
341-
342- def forward (self , x ):
343- self .tq .push (x .cos ())
344- self .tq .push (x .sin ())
345- # TODO: int return type in fallback kernel not support yet
346- x_cos = self .tq .pop () # + self.tq.size()
347- x_sin = self .tq .pop () # - self.tq.size()
348- return x_sin , x_cos
349-
350- inputs = (torch .randn (3 , 2 ),)
351-
352- q = _empty_tensor_queue ()
353- m = Foo (q )
354- orig_res = m (* inputs )
355-
356- q2 = _empty_tensor_queue ()
357- m2 = Foo (q2 )
358-
359- # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
360- with enable_torchbind_tracing ():
361- ep = torch .export .export (m2 , inputs , strict = False )
325+ aot_compile (ep .module (), inputs , options = {"aot_inductor.package" : True })
362326
363- pt2_path = torch ._inductor .aoti_compile_and_package (ep )
364- optimized = torch ._inductor .aoti_load_package (pt2_path )
365- result = optimized (* inputs )
366- self .assertEqual (result , orig_res )
327+ # TODO: add accuracy test after we support loading and running compiled models with
328+ # torchbind objects.
367329
368330 @requires_gpu ()
369331 @torch ._dynamo .config .patch ("capture_dynamic_output_shape_ops" , True )
0 commit comments