class MutableStateModule(torch.nn.Module):
    """Module with two mutable buffers, used to exercise buffer write-back.

    ``state`` is a running counter bumped on every call;
    ``direct_copy_from_input`` mirrors the most recent input verbatim.
    """

    def __init__(self):
        super().__init__()
        # Counter buffer: added to the input, then incremented in place.
        self.register_buffer("state", torch.zeros(1))
        # Buffer that is overwritten with the raw input on each call.
        self.register_buffer("direct_copy_from_input", torch.zeros(1))

    def forward(self, x):
        """Return ``x + state`` and mutate both buffers in place."""
        result = x + self.state
        # The two in-place buffer updates are independent of each other
        # and of ``result``, which was already computed above.
        self.direct_copy_from_input.copy_(x)
        self.state.add_(1)
        return result
12991301
# Export + to_edge functionalizes the buffer mutations, so the lowered
# graph starts with zero explicit copy_ ops.
model = to_edge(export(MutableStateModule(), (torch.zeros(1),), strict=True))
self.assertEqual(count_copies(model.exported_program().graph_module), 0)
# Before
# graph():
#     %b_state : [num_users=2] = placeholder[target=b_state]
#     %b_direct_copy_from_input : [num_users=0] = placeholder[target=b_direct_copy_from_input]
#     %_lifted_tensor_constant2 : [num_users=1] = placeholder[target=_lifted_tensor_constant2]
#     %x : [num_users=2] = placeholder[target=x]
#     %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%x, %b_state), kwargs = {})
#     %dim_order_ops__to_dim_order_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.dim_order_ops._to_dim_order_copy.default](args = (%_lifted_tensor_constant2,), kwargs = {dtype: torch.float32, dim_order: []})
#     %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%b_state, %dim_order_ops__to_dim_order_copy_default), kwargs = {})
#     return (aten_add_tensor_1, x, aten_add_tensor)
gm, _ = insert_write_back_for_buffers_pass(model.exported_program())

# After the pass, each mutated buffer gets an explicit copy_ write-back:
# one for the updated `state`, one for `direct_copy_from_input` (which is
# written directly from the input placeholder `x`).
# After
# graph():
#     %b_state : [num_users=3] = placeholder[target=b_state]
#     %b_direct_copy_from_input : [num_users=1] = placeholder[target=b_direct_copy_from_input]
#     %_lifted_tensor_constant2 : [num_users=1] = placeholder[target=_lifted_tensor_constant2]
#     %x : [num_users=2] = placeholder[target=x]
#     %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%x, %b_state), kwargs = {})
#     %dim_order_ops__to_dim_order_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.dim_order_ops._to_dim_order_copy.default](args = (%_lifted_tensor_constant2,), kwargs = {dtype: torch.float32, dim_order: []})
#     %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%b_state, %dim_order_ops__to_dim_order_copy_default), kwargs = {})
#     %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%b_state, %aten_add_tensor_1), kwargs = {})
#     %copy__default_1 : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%b_direct_copy_from_input, %x), kwargs = {})
#     return (copy__default, copy__default_1, aten_add_tensor)
self.assertEqual(count_copies(gm), 2)
13241329
13251330 def test_remove_quantized_op_noop_pass (self ) -> None :
13261331 class TestAddSliceNoop (torch .nn .Module ):
0 commit comments