@@ -690,6 +690,14 @@ def __init__(
 
         self._params_template = params
         self._backend = "jax"
+        self._compiled_v_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+        self._compiled_vg_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
 
         logger.info("Running cotengra pathfinder... (This may take a while)")
         nodes = self.nodes_fn(self._params_template)
@@ -844,20 +852,29 @@ def compute_and_add() -> Tensor:
 
         return device_sum_fn
 
-    # --- Public API ---
-    def value_and_grad(
+    def _get_or_compile_fn(
         self,
-        params: Tensor,
-        aggregate: bool = True,
-        op: Optional[Callable[[Tensor], Tensor]] = None,
-        output_dtype: Optional[str] = None,
-    ) -> Tuple[Tensor, Tensor]:
-        if self._compiled_vg_fn is None:
-            device_sum_fn = self._get_device_sum_vg_fn(op=op, output_dtype=output_dtype)
-            # `tree` is arg 0, `params` is arg 1, `indices` is arg 2
-            # `tree` is static and broadcast to all devices
-            self._compiled_vg_fn = jaxlib.pmap(
-                device_sum_fn,
+        cache: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ],
+        fn_getter: Callable[..., Any],
+        op: Optional[Callable[[Tensor], Tensor]],
+        output_dtype: Optional[str],
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        """
+        Get a compiled pmap-ed function from the cache, or compile and cache it on a miss.
+
+        The cache key is (op, output_dtype); a fresh lambda passed as `op` is a new object on every call and will always miss the cache.
+
+        Returns:
+            The compiled, pmap-ed JAX function.
+        """
+        cache_key = (op, output_dtype)
+        if cache_key not in cache:
+            device_fn = fn_getter(op=op, output_dtype=output_dtype)
+            compiled_fn = jaxlib.pmap(
+                device_fn,
                 in_axes=(
                     None,
                     None,
@@ -866,10 +883,39 @@ def value_and_grad(
                 static_broadcasted_argnums=(0,),  # arg 0 (tree) is a static argument
                 devices=self.devices,
             )
-        # Pass `self.tree` as the first argument
-        device_values, device_grads = self._compiled_vg_fn(  # type: ignore
+            cache[cache_key] = compiled_fn  # type: ignore
+        return cache[cache_key]  # type: ignore
+
+    def value_and_grad(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Calculate the value and gradient, compiling and caching the pmap-ed function on the first call for each (op, output_dtype) pair.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
+        :type aggregate: bool, optional
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to `backend.real`)
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `rdtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_vg_fn = self._get_or_compile_fn(
+            cache=self._compiled_vg_fns,
+            fn_getter=self._get_device_sum_vg_fn,
+            op=op,
+            output_dtype=output_dtype,
+        )
+
+        device_values, device_grads = compiled_vg_fn(
             self.tree, params, self.batched_slice_indices
         )
+
         if aggregate:
             total_value = backend.sum(device_values)
             total_grad = jaxlib.tree_util.tree_map(
@@ -885,17 +931,27 @@ def value(
         op: Optional[Callable[[Tensor], Tensor]] = None,
         output_dtype: Optional[str] = None,
     ) -> Tensor:
-        if self._compiled_v_fn is None:
-            device_sum_fn = self._get_device_sum_v_fn(op=op, output_dtype=output_dtype)
-            self._compiled_v_fn = jaxlib.pmap(
-                device_sum_fn,
-                in_axes=(None, None, 0),
-                static_broadcasted_argnums=(0,),
-                devices=self.devices,
-            )
-        device_values = self._compiled_v_fn(  # type: ignore
-            self.tree, params, self.batched_slice_indices
+        """
+        Calculate the value, compiling and caching the pmap-ed function on the first call for each (op, output_dtype) pair.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
+        :type aggregate: bool, optional
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to identity)
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `dtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_v_fn = self._get_or_compile_fn(
+            cache=self._compiled_v_fns,
+            fn_getter=self._get_device_sum_v_fn,
+            op=op,
+            output_dtype=output_dtype,
         )
+
+        device_values = compiled_v_fn(self.tree, params, self.batched_slice_indices)
+
         if aggregate:
             return backend.sum(device_values)
         return device_values
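
Not part of the diff: a minimal, self-contained sketch of the caching pattern that `_get_or_compile_fn` introduces, and of the lambda caveat in its docstring. Here `make_device_fn`, `get_or_compile`, and the module-level `cache` dict are illustrative stand-ins for the class's `_get_device_sum_v_fn` / `_get_device_sum_vg_fn` builders and `_compiled_*_fns` attributes, and `jax.jit` stands in for `jaxlib.pmap` so the sketch runs on a single device.

```python
from typing import Any, Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp


def make_device_fn(
    op: Optional[Callable[[jnp.ndarray], jnp.ndarray]],
    output_dtype: Optional[str],
) -> Callable[[jnp.ndarray], jnp.ndarray]:
    # Stand-in for the per-device function builders: returns a fresh function
    # for a given post-processing op and output dtype.
    post = op if op is not None else (lambda x: x)

    def device_fn(x: jnp.ndarray) -> jnp.ndarray:
        y = post(jnp.sum(x))
        return y.astype(output_dtype) if output_dtype else y

    return device_fn


# Same shape as the new _compiled_v_fns / _compiled_vg_fns: keyed by (op, output_dtype).
cache: Dict[Tuple[Optional[Callable[..., Any]], Optional[str]], Callable[..., Any]] = {}


def get_or_compile(op, output_dtype):
    key = (op, output_dtype)
    if key not in cache:  # compile once per distinct (op, output_dtype) pair
        cache[key] = jax.jit(make_device_fn(op, output_dtype))
    return cache[key]


def real_op(x):
    return jnp.real(x)


x = jnp.ones(4)
f1 = get_or_compile(real_op, "float32")
f2 = get_or_compile(real_op, "float32")  # cache hit: same compiled function object
g1 = get_or_compile(lambda t: jnp.real(t), "float32")  # fresh lambda -> new key -> recompile
g2 = get_or_compile(lambda t: jnp.real(t), "float32")  # another fresh lambda -> recompile again
print(f1 is f2, g1 is g2)  # True False
print(f1(x))  # 4.0
```

Keying the cache on the (op, output_dtype) pair, instead of a single `_compiled_v_fn` / `_compiled_vg_fn` attribute as before, also means later calls with a different post-processing op or dtype no longer silently reuse a function compiled for the first call's arguments.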
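
Likewise not from the diff: a toy, single-host sketch of the unchanged pmap argument layout that the cached functions wrap (arg 0 is static and broadcast, arg 1 is broadcast, arg 2 is split across devices along axis 0), and of what `aggregate=True` corresponds to. The `device_sum` function and its data are made up for illustration; the real per-device functions operate on `self.tree`, `params`, and `self.batched_slice_indices`.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()


def device_sum(scale, params, indices):
    # Toy per-device computation: sum a slice of params, scaled by a static constant.
    return scale * jnp.sum(params[indices])


pmapped = jax.pmap(
    device_sum,
    # Mirrors the diff: arg 0 and arg 1 are broadcast, arg 2 is split over devices.
    in_axes=(None, None, 0),
    static_broadcasted_argnums=(0,),  # arg 0 must be hashable; it is a static argument
)

params = jnp.arange(float(2 * n_dev))
batched_indices = jnp.arange(2 * n_dev).reshape(n_dev, 2)  # one batch of slice indices per device

device_values = pmapped(2.0, params, batched_indices)  # shape (n_dev,): one value per device
total = jnp.sum(device_values)  # aggregate=True in the diff corresponds to this sum
print(device_values.shape, total)
```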