@@ -833,16 +833,7 @@ def __init__(
833833 )
834834 actual_num_slices = self .tree .nslices
835835
836- print ("\n --- Contraction Path Info ---" )
837- stats = self .tree .contract_stats ()
838- print (f"Path found with { actual_num_slices } slices." )
839- print (
840- f"Arithmetic Intensity (higher is better): { self .tree .arithmetic_intensity ():.2f} "
841- )
842- print ("flops (TFlops):" , stats ["flops" ] / 2 ** 40 / self .num_devices )
843- print ("write (GB):" , stats ["write" ] / 2 ** 27 / actual_num_slices )
844- print ("size (GB):" , stats ["size" ] / 2 ** 27 )
845- print ("-----------------------------\n " )
836+ self ._report_tree_info ()
846837
847838 slices_per_device = int (np .ceil (actual_num_slices / self .num_devices ))
848839 padded_size = slices_per_device * self .num_devices
@@ -872,6 +863,19 @@ def __init__(
872863
873864 logger .info ("Initialization complete." )
874865
866+ def _report_tree_info (self ) -> None :
867+ print ("\n --- Contraction Path Info ---" )
868+ actual_num_slices = self .tree .nslices
869+ stats = self .tree .contract_stats ()
870+ print (f"Path found with { actual_num_slices } slices." )
871+ print (
872+ f"Arithmetic Intensity (higher is better): { self .tree .arithmetic_intensity ():.2f} "
873+ )
874+ print ("flops (TFlops):" , stats ["flops" ] / 2 ** 40 / self .num_devices )
875+ print ("write (GB):" , stats ["write" ] / 2 ** 27 / actual_num_slices )
876+ print ("size (GB):" , stats ["size" ] / 2 ** 27 )
877+ print ("-----------------------------\n " )
878+
875879 @staticmethod
876880 def _get_tree_data (
877881 nodes_fn : Callable [[Tensor ], List [Gate ]],
0 commit comments