@@ -3369,6 +3369,56 @@ def forward(self, q, k, v, attn_bias):
33693369 )
33703370 self .check_model (Model (), example_inputs )
33713371
3372+ def test_aoti_runtime_asserts (self ):
3373+ from torch ._dispatch .python import enable_python_dispatcher
3374+ from torch .export ._draft_export import draft_export
3375+
3376+ with torch .library ._scoped_library ("mylib" , "FRAGMENT" ) as lib :
3377+ torch .library .define (
3378+ "mylib::foo" ,
3379+ "(Tensor a, Tensor b) -> Tensor" ,
3380+ tags = torch .Tag .pt2_compliant_tag ,
3381+ lib = lib ,
3382+ )
3383+
3384+ @torch .library .impl ("mylib::foo" , "cpu" , lib = lib )
3385+ def foo (a : torch .Tensor , b : torch .Tensor ) -> torch .Tensor :
3386+ return a [: b .item ()]
3387+
3388+ @torch .library .impl_abstract ("mylib::foo" , lib = lib )
3389+ def foo_fake_impl (a , b ):
3390+ ctx = torch .library .get_ctx ()
3391+ u = ctx .new_dynamic_size ()
3392+ return torch .empty (u )
3393+
3394+ class M (torch .nn .Module ):
3395+ def forward (self , a , b ):
3396+ res = torch .ops .mylib .foo (a , b )
3397+ s = res .shape [0 ]
3398+ torch ._check (s > 3 )
3399+ torch ._check (s < a .shape [0 ])
3400+ return a [s - 3 ]
3401+
3402+ example_inputs = (torch .randn (100 ), torch .tensor (10 ))
3403+ ep = draft_export (M (), example_inputs )
3404+ m = ep .module ()
3405+ from torch .fx .passes .fake_tensor_prop import FakeTensorProp
3406+
3407+ example_inputs = [
3408+ node .meta ["val" ] for node in m .graph .nodes if node .op == "placeholder"
3409+ ]
3410+ fake_mode = example_inputs [0 ].fake_mode
3411+ with enable_python_dispatcher (), fake_mode :
3412+ FakeTensorProp (m , mode = fake_mode ).propagate_dont_convert_inputs (
3413+ * example_inputs
3414+ )
3415+
3416+ # TODO: change to the tests below after MetadataMismatchError is fixed
3417+ # pt2_file = torch._inductor.aoti_compile_and_package(ep)
3418+ # optimized = torch._inductor.aoti_load_package(pt2_file)
3419+
3420+ # self.assertTrue(same(optimized(example_inputs), m(example_inputs)))
3421+
33723422 def test_index_put_with_none_index (self ):
33733423 # index_put falls back in the deterministic mode
33743424 with DeterministicGuard (True ):
0 commit comments