Commit 3be3f33

implement correct function cache for DistributedContractor

1 parent 89ee53e · commit 3be3f33

File tree: 3 files changed, +93 -26 lines


CHANGELOG.md

Lines changed: 11 additions & 0 deletions

```diff
@@ -1,8 +1,19 @@
 # Change Log
 
 ## Unreleased
+
+### Added
+
 - Add `Lattice` module (`tensorcircuit.templates.lattice`) for creating and manipulating various lattice geometries, including `SquareLattice`, `HoneycombLattice`, and `CustomizeLattice`.
 
+- Add `DistributedContractor` in the experimental module, with new examples, for fast distributed circuit simulation on the jax backend.
+
+- Add `circuit.amplitude_before()` method to return the corresponding tensornetwork nodes.
+
+### Fixed
+
+- Fix the node order in contraction by giving each node a global `_stable_id_`.
+
 ## v1.2.1
 
 ### Fixed
```
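For context, a minimal usage sketch of the new `DistributedContractor`, the subject of this commit. The constructor shape `DistributedContractor(nodes_fn, params)` is inferred from `self.nodes_fn(self._params_template)` in the diff below; the circuit construction and the `expectation_before` call are illustrative assumptions, not a confirmed example from the repository:

```python
# Hedged sketch: assumes a constructor of the form DistributedContractor(nodes_fn, params),
# inferred from the diff below; not a confirmed public API.
import tensorcircuit as tc
from tensorcircuit.experimental import DistributedContractor

tc.set_backend("jax")  # DistributedContractor targets the jax backend

n = 6

def nodes_fn(params):
    # Return uncontracted tensor network nodes; the contractor plans
    # and slices the contraction across the available devices.
    c = tc.Circuit(n)
    for i in range(n):
        c.rx(i, theta=params[i])
    for i in range(n - 1):
        c.cnot(i, i + 1)
    return c.expectation_before([tc.gates.z(), [0]], reuse=False)

params = tc.backend.ones([n])
dc = DistributedContractor(nodes_fn, params)  # runs the cotengra pathfinder once

v = dc.value(params)               # pmap-compiled and cached on first call
v2, g = dc.value_and_grad(params)  # separate cache, keyed by (op, output_dtype)
```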

tensorcircuit/experimental.py

Lines changed: 81 additions & 25 deletions

```diff
@@ -690,6 +690,14 @@ def __init__(
 
         self._params_template = params
         self._backend = "jax"
+        self._compiled_v_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+        self._compiled_vg_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
 
         logger.info("Running cotengra pathfinder... (This may take a while)")
         nodes = self.nodes_fn(self._params_template)
@@ -844,20 +852,29 @@ def compute_and_add() -> Tensor:
 
         return device_sum_fn
 
-    # --- Public API ---
-    def value_and_grad(
+    def _get_or_compile_fn(
         self,
-        params: Tensor,
-        aggregate: bool = True,
-        op: Optional[Callable[[Tensor], Tensor]] = None,
-        output_dtype: Optional[str] = None,
-    ) -> Tuple[Tensor, Tensor]:
-        if self._compiled_vg_fn is None:
-            device_sum_fn = self._get_device_sum_vg_fn(op=op, output_dtype=output_dtype)
-            # `tree` is arg 0, `params` is arg 1, `indices` is arg 2
-            # `tree` is static and broadcast to all devices
-            self._compiled_vg_fn = jaxlib.pmap(
-                device_sum_fn,
+        cache: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ],
+        fn_getter: Callable[..., Any],
+        op: Optional[Callable[[Tensor], Tensor]],
+        output_dtype: Optional[str],
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        """
+        Gets a compiled pmap-ed function from the cache, or compiles and caches it.
+
+        The cache key is the tuple (op, output_dtype). Caution with lambda functions!
+
+        Returns:
+            The compiled, pmap-ed JAX function.
+        """
+        cache_key = (op, output_dtype)
+        if cache_key not in cache:
+            device_fn = fn_getter(op=op, output_dtype=output_dtype)
+            compiled_fn = jaxlib.pmap(
+                device_fn,
             in_axes=(
                 None,
                 None,
@@ -866,10 +883,39 @@
                 static_broadcasted_argnums=(0,),  # arg 0 (tree) is a static argument
                 devices=self.devices,
             )
-        # Pass `self.tree` as the first argument
-        device_values, device_grads = self._compiled_vg_fn(  # type: ignore
+            cache[cache_key] = compiled_fn  # type: ignore
+        return cache[cache_key]  # type: ignore
+
+    def value_and_grad(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Calculates the value and gradient, compiling the pmap function if needed on the first call.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
+        :type aggregate: bool, optional
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to `backend.real`)
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `rdtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_vg_fn = self._get_or_compile_fn(
+            cache=self._compiled_vg_fns,
+            fn_getter=self._get_device_sum_vg_fn,
+            op=op,
+            output_dtype=output_dtype,
+        )
+
+        device_values, device_grads = compiled_vg_fn(
             self.tree, params, self.batched_slice_indices
         )
+
         if aggregate:
             total_value = backend.sum(device_values)
             total_grad = jaxlib.tree_util.tree_map(
@@ -885,17 +931,27 @@ def value(
         op: Optional[Callable[[Tensor], Tensor]] = None,
         output_dtype: Optional[str] = None,
     ) -> Tensor:
-        if self._compiled_v_fn is None:
-            device_sum_fn = self._get_device_sum_v_fn(op=op, output_dtype=output_dtype)
-            self._compiled_v_fn = jaxlib.pmap(
-                device_sum_fn,
-                in_axes=(None, None, 0),
-                static_broadcasted_argnums=(0,),
-                devices=self.devices,
-            )
-        device_values = self._compiled_v_fn(  # type: ignore
-            self.tree, params, self.batched_slice_indices
+        """
+        Calculates the value, compiling the pmap function if needed on the first call.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
+        :type aggregate: bool, optional
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to identity)
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `dtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_v_fn = self._get_or_compile_fn(
+            cache=self._compiled_v_fns,
+            fn_getter=self._get_device_sum_v_fn,
+            op=op,
+            output_dtype=output_dtype,
         )
+
+        device_values = compiled_v_fn(self.tree, params, self.batched_slice_indices)
+
         if aggregate:
             return backend.sum(device_values)
         return device_values
```
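The docstring's warning about lambdas deserves emphasis: the cache key contains the `op` function object itself, so two textually identical lambdas are distinct keys and each triggers a fresh `pmap` compilation. A standalone sketch of the keying behavior (all names here are hypothetical stand-ins, not tensorcircuit API):

```python
from typing import Callable, Dict, Optional, Tuple

# Hypothetical stand-in for the pmap-compile step; counts compilations.
compilations = 0

def fake_compile(op: Optional[Callable], output_dtype: Optional[str]) -> Callable:
    global compilations
    compilations += 1
    return lambda x: x  # placeholder for the compiled function

cache: Dict[Tuple[Optional[Callable], Optional[str]], Callable] = {}

def get_or_compile(op: Optional[Callable], output_dtype: Optional[str]) -> Callable:
    key = (op, output_dtype)  # same keying scheme as _get_or_compile_fn
    if key not in cache:
        cache[key] = fake_compile(op, output_dtype)
    return cache[key]

def real_op(x):
    return x.real

get_or_compile(real_op, "float32")
get_or_compile(real_op, "float32")  # cache hit: same function object
assert compilations == 1

get_or_compile(lambda x: x.real, "float32")
get_or_compile(lambda x: x.real, "float32")  # cache misses: distinct lambda objects
assert compilations == 3
```

So callers who want to reuse a compiled function should define `op` once (a named function or a stored lambda reference) and pass the same object on every call.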

tests/test_stabilizer.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -184,7 +184,7 @@ def test_depolarize():
         c.depolarizing(0, 1, p=0.2)
         c.h(0)
         r.append(c.expectation_ps(z=[0]))
-    assert 5 < np.sum(r) < 38
+    assert 4 < np.sum(r) < 39
 
 
 def test_tableau_inputs():
```
