@@ -1719,6 +1719,83 @@ def forward(self, x):
17191719 self .assertEqual (external_map ["linear.weight" ], 0 )
17201720 self .assertEqual (external_map ["linear.bias" ], 1 )
17211721
1722+ def test_constant_tagged_tensor_dedup (self ) -> None :
1723+ class ConstantModule (nn .Module ):
1724+ def __init__ (self ):
1725+ super ().__init__ ()
1726+ constant = torch .tensor ([1.0 , 2.0 , 3.0 ])
1727+
1728+ # Register the same value with two different names as persistent buffers
1729+ self .register_buffer ("c0" , constant .clone (), persistent = True )
1730+ self .register_buffer ("c1" , constant .clone (), persistent = True )
1731+
1732+ def forward (self , x ):
1733+ return x + self .c0 + self .c1
1734+
1735+ model = to_edge (
1736+ export (ConstantModule (), (torch .ones (1 , 3 ),), strict = True )
1737+ ).to_executorch (
1738+ config = ExecutorchBackendConfig (
1739+ external_constants = True ,
1740+ )
1741+ )
1742+ emitter_output = model ._emitter_output
1743+ # constant_buffer is empty besides the non-constant placeholder 0.
1744+ self .assertEqual (len (emitter_output .program .constant_buffer ), 1 )
1745+ # only one item in the external constant buffer.
1746+ self .assertEqual (len (emitter_output .external_constant_buffer ), 1 )
1747+ # Setting external_constants=True, saves all constants to the key
1748+ # '_default_external_constant'.
1749+ external_map = emitter_output .external_constant_map [
1750+ "_default_external_constant"
1751+ ]
1752+ self .assertEqual (len (external_map ), 2 )
1753+ self .assertEqual (external_map ["c0" ], 0 )
1754+ self .assertEqual (external_map ["c1" ], 0 )
1755+
1756+ def test_constant_tagged_tensor_dedup_2 (self ) -> None :
1757+ class ConstantModule (nn .Module ):
1758+ def __init__ (self ):
1759+ super ().__init__ ()
1760+ constant0_4 = torch .tensor ([1.0 , 2.0 , 3.0 ])
1761+ constant4_5 = torch .tensor ([2.0 , 3.0 , 4.0 ])
1762+
1763+ # Register the same value with two different names as persistent buffers
1764+ self .register_buffer ("c0" , constant0_4 .clone (), persistent = True )
1765+ self .register_buffer ("c1" , constant0_4 .clone (), persistent = True )
1766+ self .register_buffer ("c2" , constant0_4 .clone (), persistent = True )
1767+ self .register_buffer ("c3" , constant0_4 .clone (), persistent = True )
1768+ self .register_buffer ("c4" , constant4_5 .clone (), persistent = True )
1769+ self .register_buffer ("c5" , constant4_5 .clone (), persistent = True )
1770+
1771+ def forward (self , x ):
1772+ return x + self .c0 + self .c1 + self .c2 + self .c3 + self .c4 + self .c5
1773+
1774+ model = to_edge (
1775+ export (ConstantModule (), (torch .ones (1 , 3 ),), strict = True )
1776+ ).to_executorch (
1777+ config = ExecutorchBackendConfig (
1778+ external_constants = True ,
1779+ )
1780+ )
1781+ emitter_output = model ._emitter_output
1782+ # constant_buffer is empty besides the non-constant placeholder 0.
1783+ self .assertEqual (len (emitter_output .program .constant_buffer ), 1 )
1784+ # Two items in the external constant buffer.
1785+ self .assertEqual (len (emitter_output .external_constant_buffer ), 2 )
1786+ # Setting external_constants=True, saves all constants to the key
1787+ # '_default_external_constant'.
1788+ external_map = emitter_output .external_constant_map [
1789+ "_default_external_constant"
1790+ ]
1791+ self .assertEqual (len (external_map ), 6 )
1792+ self .assertEqual (external_map ["c0" ], 0 )
1793+ self .assertEqual (external_map ["c1" ], 0 )
1794+ self .assertEqual (external_map ["c2" ], 0 )
1795+ self .assertEqual (external_map ["c3" ], 0 )
1796+ self .assertEqual (external_map ["c4" ], 1 )
1797+ self .assertEqual (external_map ["c5" ], 1 )
1798+
17221799 def test_delegate_deduplicate (self ) -> None :
17231800 class SharedModule (torch .nn .Module ):
17241801 def __init__ (self ):
0 commit comments