Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tn_generative/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,15 @@ def get_ruby_vanderwaals(
"""Generates ruby rydberg physical system for `[size_x, size_y]` domain
with specification of detuning `delta` parameter."""
return physical_systems.RubyRydbergVanderwaals(size_x, size_y, delta)


@register_task('cluster_state')
def get_cluster_state(
size_x: int,
size_y: int,
onsite_z_field: float = 0.0,
) -> PhysicalSystem:
"""Generates cluster state for `[size_x, size_y]` domain with specification
of onsite z field coupling `onsite_z_field`.
"""
return physical_systems.ClusterState(size_x, size_y, onsite_z_field)
38 changes: 37 additions & 1 deletion tn_generative/mps_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,46 @@ def random_uniform_basis_sampler(
return base_sample_fn(sample_key, rotated_mps), basis


def xz_neel_basis_sampler(
key: jax.random.PRNGKeyArray,
mps: qtn.MatrixProductState,
neel_probabilities: Tuple[float, float],
base_sample_fn: SamplerFn = gibbs_sampler,
) -> MeasurementAndBasis:
"""Draws a sample from `mps` in alternating X/Z basis selected randomly.

Samples `mps` in an alternating X/Z basis which is selected randomly
with probabilities neel_probabilities.
[XZX..., ZXZ...] are mapped to [0, 1].

Args:
key: random key used to draw a sample.
mps: matrix product state in `z` basis from which to draw a sample.
neel_probabilities: probabilities for selecting XZX.../ZXZ... basis.
base_sample_fn: sampler method. Default is gibbs_sampler.

Returns:
Tuple of mesurement sample and basis.
"""
neel_probabilities = jnp.asarray(neel_probabilities)
sample_key, basis_key = jax.random.split(key, 2)
basis_val_start = jax.random.choice(basis_key, np.arange(2),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot - is np.arange(2) x and z? I though that 0, 1, 2 were mapped to x, y z basis.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think overall you probably want:

xz_basis = ...
zx_basis = ...
basis = jax.random.choice(basis_key, np.stack([xz_basis, zx_basis]))

p=neel_probabilities
)
basis = (jnp.arange(mps.L) + basis_val_start) % 2 * 2.
mps = mps.copy()
rotation_mpo = mps_utils.z_to_basis_mpo(basis)
rotated_mps = rotation_mpo.apply(mps)
return base_sample_fn(sample_key, rotated_mps), basis


register_sampler('x_basis_sampler')(
functools.partial(fixed_basis_sampler, basis=0))
register_sampler('z_basis_sampler')(
functools.partial(fixed_basis_sampler, basis=2))
register_sampler('xz_basis_sampler')(
functools.partial(random_uniform_basis_sampler,
x_y_z_probabilities=[0.5, 0.0, 0.5]))
x_y_z_probabilities=[0.5, 0.0, 0.5]
))
register_sampler('xz_neel_basis_sampler')(
functools.partial(xz_neel_basis_sampler, neel_probabilities=[0.5, 0.5]))
184 changes: 150 additions & 34 deletions tn_generative/physical_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import abc
from abc import abstractmethod
import itertools
import functools
from typing import Callable

import numpy as np
Expand Down Expand Up @@ -33,21 +32,21 @@ def get_ham(self) -> qtn.MatrixProductOperator:
"""Returns a hamiltonian MPO."""

def get_sparse_operator(
self,
self,
terms: types.TermsTuple,
) -> quimb_exp_op.SparseOperatorBuilder:
"""Generates operator including `terms` using sparse operator builder."""
if self.hilbert_space is None:
raise ValueError(
f'subclass {self.__name__} did not implement custom `hilbert_space`.'
)
)
sparse_operator = quimb_exp_op.SparseOperatorBuilder(
hilbert_space=self.hilbert_space
)
)
for term in terms: # add all terms to the operator.
sparse_operator += term
return sparse_operator

def get_ham_mpos(self) -> list[qtn.MatrixProductOperator]:
"""Returns MPOs for list of terms in the hamiltonian.

Expand All @@ -59,8 +58,8 @@ def get_ham_mpos(self) -> list[qtn.MatrixProductOperator]:
if self.get_terms() is None:
raise ValueError(
f'subclass {self.__name__} did not implement custom `get_terms`.'
f'subclass {self.__name__} should either implement custom'
'`get_ham_mpos` or provide `hilbert_space` and implement `get_terms`.'
f'subclass {self.__name__} should either implement custom'
'`get_ham_mpos` or provide `hilbert_space` and implement `get_terms`.'
)
mpos = []
terms = [(1., *term[1:]) for term in self.get_terms()]
Expand Down Expand Up @@ -247,8 +246,8 @@ def get_ham(self) -> qtn.MatrixProductOperator:
class RubyRydbergVanderwaals(PhysicalSystem): #TODO(YT): add tests.
"""Implementation for ruby Rydberg hamiltonian.

Note: this constructor assumes Van der Waals interactions among neibours.
The range of neibours are determined by Callable `nb_ratio_fn`, depending on
Note: this constructor assumes Van der Waals interactions among neibours.
The range of neibours are determined by Callable `nb_ratio_fn`, depending on
ruby lattice aspect ratio, `rho`specified in ascending order.

Args:
Expand All @@ -258,7 +257,7 @@ class RubyRydbergVanderwaals(PhysicalSystem): #TODO(YT): add tests.
rho: aspect ratio of the ruby lattice.
rb: Rydberg blockade radius, in units of lattice spacing.
omega: laser Rabi frequency, `x` field.
nb_ratio_fn: Callable that returns a tuple of ascending neibour radii.
nb_ratio_fn: Callable that returns a tuple of ascending neibour radii.

Returns:
Ruby Rydberg hamiltonian Physical system.
Expand All @@ -269,58 +268,60 @@ def __init__(
Lx: int,
Ly: int,
delta: float = 5.0,
rho: float = 3.,
rb: float = 3.8,
rho: float = 3.,
rb: float = 3.8,
omega: float = 1.,
nb_ratio_fn: Callable[[float], tuple[float, ...]] = lambda rho: (
1., rho, np.sqrt(1. + rho**2)
),
),
):
self.n_sites = int(Lx * Ly * 6)
self.Lx = Lx
self.Ly = Ly
self.delta = delta
self.a = 1. / 4. # lattice spacing.
self.omega = omega
self.rho = rho
self.rho = rho
self.epsilon = 1e-3
self.nb_radii = tuple(r * self.a + self.epsilon for r in nb_ratio_fn(self.rho))
self.nb_radii = tuple(
r * self.a + self.epsilon for r in nb_ratio_fn(self.rho)
)
self.vs = np.array([(rb / r)**6 for r in nb_ratio_fn(self.rho)])

self._lattice = self._get_expanded_lattice(
self.rho, self.Lx, self.Ly, self.a
) # COMMENT: I don't seem to need __post_init__ here.

@property
def hilbert_space(self) -> types.HilbertSpace:
return quimb_exp_op.HilbertSpace(self.n_sites)

def _get_expanded_lattice(
self,
rho: float,
Lx: int,
rho: float,
Lx: int,
Ly: int,
a: float,
a: float,
) -> lattices.Lattice:
"""Constructs lattice for rydberg Hamiltonian.
Args:
rho: aspect ratio of the ruby lattice.
Lx: number of unit cells in x direction.
Ly: number of unit cells in y direction.
a: lattice spacing.
Returns: Expanded lattice.

Returns: Expanded lattice.
"""
unit_cell_points = np.array(
[[1. / 4., 0.], [1./ 8., np.sqrt(3) / 8.],
[[1. / 4., 0.], [1./ 8., np.sqrt(3) / 8.],
[3. / 8., np.sqrt(3) / 8.], [1. / 8., np.sqrt(3) / 8. + a * rho],
[3. / 8., np.sqrt(3) / 8. + a * rho],
[3. / 8., np.sqrt(3) / 8. + a * rho],
[1. / 4., np.sqrt(3) / 4. + a * rho]]
)
)
unit_cell = lattices.Lattice(unit_cell_points)
a1 = np.array([2 * a * self.rho * np.sqrt(3) / 2. + a, 0.0])
a2 = np.array([
a * rho * np.sqrt(3) / 2. + a / 2.,
a * rho * np.sqrt(3) / 2. + a / 2.,
a * rho * 1. / 2. + a * np.sqrt(3) / 2 + a * rho
])
expanded_lattice = sum(
Expand All @@ -334,21 +335,21 @@ def _get_annulus_bonds(
nb_outer: float,
nb_inner: float = 0.,
) -> node_collections.NodesCollection:
"""Constructs `NodeCollection`s for bonds between an annulus of radius
"""Constructs `NodeCollection`s for bonds between an annulus of radius
`nb_outer` and `nb_inner` nearest neighbour in the PXP rydberg Hamiltonian.

Args:
nb_outer: radius of outer annulus.
nb_inner: radius of inner annulus.

Returns:
Bonds within an annulus of radius `nb_outer` and `nb_inner`.
Bonds within an annulus of radius `nb_outer` and `nb_inner`.
"""
nn_bonds = node_collections.get_nearest_neighbors(
self._lattice, nb_outer, nb_inner
)
return nn_bonds

def _get_nearest_neighbour_bonds(
self,
) -> list[node_collections.NodesCollection]:
Expand All @@ -362,13 +363,13 @@ def _get_nearest_neighbour_bonds(
raise ValueError(
f'`nb_radii` must be in ascending order. '
f'{self.nb_radii[i - 1]=}` is greater than {self.nb_radii[i]=}`.'
)
)
nn_bonds = self._get_annulus_bonds(
self.nb_radii[i], self.nb_radii[i - 1]
)
all_nn_bonds.append(nn_bonds)
return all_nn_bonds

def _get_nearest_neighbour_groups(self) -> list[types.TermsTuple]:
"""Constuct terms for nearest neighbour bonds between each annulus."""
all_nn_bonds = self._get_nearest_neighbour_bonds()
Expand All @@ -394,7 +395,7 @@ def _get_onsite_groups(self) -> list[types.TermsTuple]:
def _get_all_terms_groups(self) -> list[types.TermsTuple]:
"""Get all terms in hamiltonian as list of groups."""
return self._get_nearest_neighbour_groups() + self._get_onsite_groups()

def get_terms(self) -> types.TermsTuple:
"""Merge all terms from all groups into one tuple."""
all_terms_groups = self._get_all_terms_groups()
Expand All @@ -411,3 +412,118 @@ def get_ham(self) -> qtn.MatrixProductOperator:
self.get_sparse_operator(terms).build_mpo()
)
return sum(hamiltonian_mpo_groups[1:], start=hamiltonian_mpo_groups[0])


class ClusterState(PhysicalSystem):
"""Implementation for cluster state hamiltonian.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Add one blank like separating one-line comment and verbose comment.

Note: this constructor assumes antiferromagnetic coupling.
`coupling_value == 1` corresponds to `H = -1 * (Σ S) ...`.
"""

def __init__(
self,
Lx: int,
Ly: int,
onsite_z_field: float = 0.,
coupling: float = 1.0,
):
self.n_sites = int(Lx * Ly)
self.Lx = Lx
self.Ly = Ly
self.coupling = coupling
self.onsite_z_field = onsite_z_field

@property
def hilbert_space(self) -> types.HilbertSpace:
return quimb_exp_op.HilbertSpace(self.n_sites)

def _get_expanded_lattice(self) -> lattices.Lattice:
"""Constructs lattice for cluster state.

Returns:
Expanded lattice.
"""
unit_cell_points = np.stack([np.array([0, 0])])
unit_cell = lattices.Lattice(unit_cell_points)
a1 = np.array([1.0, 0.0]) # unit vectors for square lattice.
a2 = np.array([0.0, 1.0])
expanded_lattice = sum(
unit_cell.shift(a1 * i + a2 * j)
for i, j in itertools.product(range(self.Lx), range(self.Ly))
)
return expanded_lattice

def _get_stabilizer_groups(self) -> list[node_collections.NodesCollection]:
"""Constructs `NodeCollection`s for all groups in cluster state Hamiltonian.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super nit: make it 1-line or a full docstring

Copy link
Owner Author

@teng10 teng10 Aug 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is essentially one-line doc, you are saying the """ should be in the same line?
I was looking at quimb doc and it looks like they do this in general
"""Some doc
"""

"""
expanded_lattice = self._get_expanded_lattice()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to add self._lattice similar to the previous instance?

nn_bonds = node_collections.get_nearest_neighbors(
expanded_lattice, 1. + 1e-3
)
stabilizer_bonds = []
for i in range(self.n_sites):
stabilizer_bond = []
for bond in nn_bonds.nodes:
if i in bond:
stabilizer_bond.append(bond)
stabilizer_bonds.append(np.concatenate(stabilizer_bond))

lengths = set(
[len(stabilizer_bond) for stabilizer_bond in stabilizer_bonds]
)
stabilizer_bonds_groups = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as per offline chat flagging that this is being reset inside of the loop and used outside later. This is probably still correct because you accumulate stabilizer_bonds that is outside of the loop, but might be worthwhile trying to refactor these steps to be a bit more constructive.

for length in sorted(lengths):
stabilizer_bonds_groups.append(
np.stack(
[b for b in stabilizer_bonds if len(b) == length]
)
)
stabilizer_nodes_groups = []
for stabilizer_bond_group in stabilizer_bonds_groups:
stabilizer_nodes = node_collections.NodesCollection(
stabilizer_bond_group, expanded_lattice
)
stabilizer_nodes_groups.append(stabilizer_nodes)
return stabilizer_nodes_groups

def _get_all_terms_groups(self) -> list[types.TermsTuple]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's discuss some of these construction methods in more details to see if we can build something a bit easier to parse.

"""Constructs all terms for all groups in cluster state Hamiltonian."""
def _sort_by_counts(nodes: np.ndarray) -> np.ndarray:
"""Helper function for sorting sites by their counts.
Args: nodes: array of nodes.
Returns: sorted array of nodes, with the last site most frequent.
"""
my_counts = np.apply_along_axis(
np.unique, axis=1, arr=nodes, return_counts=True
)
arg = np.argsort(my_counts[:, 1, :])
return np.take_along_axis(my_counts[:, 0, :], arg, axis=1)

stabilizer_nodes_groups = self._get_stabilizer_groups()
stabilizer_terms_groups = []
for stabilizer_nodes in stabilizer_nodes_groups:
sorted_nodes = _sort_by_counts(stabilizer_nodes.nodes)
stabilizer_terms = []
for node in sorted_nodes:
term = (-self.coupling, ('x', node[-1])) # most frequent site is X.
term += tuple(('z', i) for i in node[:-1])
stabilizer_terms.append(term)
stabilizer_terms_groups.append(stabilizer_terms)
return stabilizer_terms_groups

def get_terms(self) -> types.TermsTuple:
"""Merge all terms from all groups into one list."""
all_terms_groups = self._get_all_terms_groups()
all_terms = []
for group in all_terms_groups:
all_terms += group
if self.onsite_z_field != 0.:
all_terms += [
(-self.onsite_z_field, ('z', i)) for i in range(self.n_sites)
]
return all_terms

def get_ham(self) -> qtn.MatrixProductOperator:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can use the default value here (since it's the same)

"""Get hamiltonian as MPO."""
ham_builder = self.get_sparse_operator(self.get_terms())
return ham_builder.build_mpo()