@@ -737,7 +737,12 @@ def __init__(
 
         logger.info("Initialization complete.")
 
-    def _get_single_slice_contraction_fn(self) -> Callable[[Any, Tensor, int], Tensor]:
+    def _get_single_slice_contraction_fn(
+        self, op: Optional[Callable[[Tensor], Tensor]] = None
+    ) -> Callable[[Any, Tensor, int], Tensor]:
+        if op is None:
+            op = backend.sum
+
         def single_slice_contraction(
             tree: ctg.ContractionTree, params: Tensor, slice_idx: int
         ) -> Tensor:
@@ -746,16 +751,25 @@ def single_slice_contraction(
             input_arrays = [node.tensor for node in standardized_nodes]
             sliced_arrays = tree.slice_arrays(input_arrays, slice_idx)
             result = tree.contract_core(sliced_arrays, backend=self._backend)
-            return backend.sum(backend.real(result))
+            return op(result)
 
         return single_slice_contraction
 
     def _get_device_sum_vg_fn(
         self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
     ) -> Callable[[Any, Tensor, Tensor], Tuple[Tensor, Tensor]]:
-        base_fn = self._get_single_slice_contraction_fn()
+        post_processing = lambda x: backend.real(backend.sum(x))
+        if op is None:
+            op = post_processing
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        # ensure the output is real so that it can be differentiated
         single_slice_vg_fn = jaxlib.value_and_grad(base_fn, argnums=1)
 
+        if output_dtype is None:
+            output_dtype = rdtypestr
+
         def device_sum_fn(
             tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
         ) -> Tuple[Tensor, Tensor]:
@@ -785,7 +799,7 @@ def do_nothing() -> Tuple[Tensor, Tensor]:
             )
 
             initial_carry = (
-                backend.cast(backend.convert_to_tensor(0.0), dtype=rdtypestr),
+                backend.cast(backend.convert_to_tensor(0.0), dtype=output_dtype),
                 jaxlib.tree_util.tree_map(lambda x: jaxlib.numpy.zeros_like(x), params),
             )
             (final_value, final_grads), _ = jaxlib.lax.scan(
@@ -795,21 +809,14 @@ def do_nothing() -> Tuple[Tensor, Tensor]:
             )
             return final_value, final_grads
 
         return device_sum_fn
 
-    def _compile_value_and_grad(self) -> None:
-        if self._compiled_vg_fn is not None:
-            return
-        device_sum_fn = self._get_device_sum_vg_fn()
-        # `tree` is arg 0, `params` is arg 1, `indices` is arg 2
-        # `tree` is static and broadcast to all devices
-        self._compiled_vg_fn = jaxlib.pmap(
-            device_sum_fn,
-            in_axes=(None, None, 0),  # tree: broadcast, params: broadcast, indices: map
-            static_broadcasted_argnums=(0,),  # arg 0 (tree) is a static argument
-            devices=self.devices,
-        )
-
-    def _get_device_sum_v_fn(self) -> Callable[[Any, Tensor, Tensor], Tensor]:
-        base_fn = self._get_single_slice_contraction_fn()
+    def _get_device_sum_v_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        if output_dtype is None:
+            output_dtype = dtypestr
 
         def device_sum_fn(
             tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
@@ -828,7 +835,7 @@ def compute_and_add() -> Tensor:
             )
 
             initial_carry = backend.cast(
-                backend.convert_to_tensor(0.0), dtype=rdtypestr
+                backend.convert_to_tensor(0.0), dtype=output_dtype
             )
             final_value, _ = jaxlib.lax.scan(
                 scan_body, initial_carry, slice_indices_for_device
@@ -837,22 +844,28 @@ def compute_and_add() -> Tensor:
             )
             return final_value
 
         return device_sum_fn
 
-    def _compile_value(self) -> None:
-        if self._compiled_v_fn is not None:
-            return
-        device_sum_fn = self._get_device_sum_v_fn()
-        self._compiled_v_fn = jaxlib.pmap(
-            device_sum_fn,
-            in_axes=(None, None, 0),
-            static_broadcasted_argnums=(0,),
-            devices=self.devices,
-        )
-
     # --- Public API ---
     def value_and_grad(
-        self, params: Tensor, aggregate: bool = True
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
     ) -> Tuple[Tensor, Tensor]:
-        self._compile_value_and_grad()
+        if self._compiled_vg_fn is None:
+            device_sum_fn = self._get_device_sum_vg_fn(op=op, output_dtype=output_dtype)
+            # `tree` is arg 0, `params` is arg 1, `indices` is arg 2
+            # `tree` is static and broadcast to all devices
+            self._compiled_vg_fn = jaxlib.pmap(
+                device_sum_fn,
+                in_axes=(
+                    None,
+                    None,
+                    0,
+                ),  # tree: broadcast, params: broadcast, indices: map
+                static_broadcasted_argnums=(0,),  # arg 0 (tree) is a static argument
+                devices=self.devices,
+            )
         # Pass `self.tree` as the first argument
         device_values, device_grads = self._compiled_vg_fn(  # type: ignore
             self.tree, params, self.batched_slice_indices
@@ -865,15 +878,36 @@ def value_and_grad(
             return total_value, total_grad
         return device_values, device_grads
 
-    def value(self, params: Tensor, aggregate: bool = True) -> Tensor:
-        self._compile_value()
+    def value(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        if self._compiled_v_fn is None:
+            device_sum_fn = self._get_device_sum_v_fn(op=op, output_dtype=output_dtype)
+            self._compiled_v_fn = jaxlib.pmap(
+                device_sum_fn,
+                in_axes=(None, None, 0),
+                static_broadcasted_argnums=(0,),
+                devices=self.devices,
+            )
         device_values = self._compiled_v_fn(  # type: ignore
             self.tree, params, self.batched_slice_indices
         )
         if aggregate:
             return backend.sum(device_values)
         return device_values
 
-    def grad(self, params: Tensor, aggregate: bool = True) -> Tensor:
-        _, grad = self.value_and_grad(params, aggregate=aggregate)
+    def grad(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        _, grad = self.value_and_grad(
+            params, aggregate=aggregate, op=op, output_dtype=output_dtype
+        )
         return grad
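
Net effect of the diff: the separate _compile_value_and_grad / _compile_value helpers are folded into the public methods, and value, grad, and value_and_grad gain two optional knobs. op post-processes each sliced contraction result before accumulation, defaulting to backend.sum on the value path and to backend.real(backend.sum(x)) on the gradient path (jaxlib.value_and_grad requires a real scalar output). output_dtype sets the dtype of the lax.scan carry, defaulting to rdtypestr for gradients and dtypestr for plain values. A minimal usage sketch, assuming `contractor` is an initialized instance of the class patched here and `params` its parameter tensors (both names are hypothetical; neither appears in the diff):

    # Hypothetical names: `contractor` is an instance of the sliced-contraction
    # class modified above; `params` are its input tensors.

    # Default path: per-slice results are summed; the gradient path wraps them
    # as real(sum(.)) so the differentiated output is a real scalar.
    value, grads = contractor.value_and_grad(params)

    # Custom reduction: `op` is applied to every sliced contraction result and
    # `output_dtype` types the scan accumulator ("float32" is an assumed
    # backend dtype string, analogous to rdtypestr/dtypestr above).
    real_value = contractor.value(
        params,
        op=lambda x: backend.real(backend.sum(x)),
        output_dtype="float32",
    )

    # With aggregate=False, the per-device partial sums are returned instead
    # of being reduced across devices.
    device_values, device_grads = contractor.value_and_grad(params, aggregate=False)

One caveat visible in the new control flow: the pmapped function is built only when _compiled_vg_fn / _compiled_v_fn is None, so op and output_dtype passed on later calls are silently ignored once compilation has happened.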