Skip to content

Commit 1c5e67e

Browse files
committed
add onehot_d_tensor in quantum.py and apply.
1 parent 793876b commit 1c5e67e

File tree

2 files changed

+43
-19
lines changed

2 files changed

+43
-19
lines changed

tensorcircuit/basecircuit.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
sample2all,
2727
_infer_num_sites,
2828
_decode_basis_label,
29+
onehot_d_tensor,
2930
)
3031
from .abstractcircuit import AbstractCircuit
3132
from .cons import npdtype, backend, dtypestr, contractor, rdtypestr
@@ -410,8 +411,7 @@ def measure_jit(
410411
np.array([1, 0])
411412
) + sample[i] * gates.array_to_tensor(np.array([0, 1]))
412413
else:
413-
vec = backend.one_hot(backend.cast(sample[i], "int32"), self._d)
414-
m = backend.cast(vec, dtypestr)
414+
m = onehot_d_tensor(sample[i], d=self._d)
415415
g1 = Gate(m)
416416
g1.id = id(g1)
417417
g1.is_dagger = False
@@ -507,29 +507,21 @@ def amplitude_before(self, l: Union[str, Tensor]) -> List[Gate]:
507507
:rtype: List[Gate]
508508
"""
509509

510-
def _basis_nod(_k: int) -> Tensor:
511-
_vec = np.zeros((self._d,), dtype=npdtype)
512-
_vec[_k] = 1.0
513-
return _vec
514-
515510
no, d_edges = self._copy()
516511
ms = []
517512
if self.is_dm:
518513
msconj = []
519514
if isinstance(l, str):
520515
symbols = _decode_basis_label(l, n=self._nqubits, dim=self._d)
521516
for k in symbols:
522-
n = _basis_nod(k)
517+
n = onehot_d_tensor(k, d=self._d)
523518
ms.append(tn.Node(n))
524519
if self.is_dm:
525520
msconj.append(tn.Node(n))
526521
else:
527522
l = backend.cast(l, dtype=dtypestr)
528523
for i in range(self._nqubits):
529-
endn = backend.cast(
530-
backend.one_hot(backend.cast(l[i], "int32"), self._d),
531-
dtype=dtypestr,
532-
)
524+
endn = onehot_d_tensor(l[i], d=self._d)
533525
ms.append(tn.Node(endn))
534526
if self.is_dm:
535527
msconj.append(tn.Node(endn))
@@ -1040,18 +1032,13 @@ def projected_subsystem(self, traceout: Tensor, left: Tuple[int, ...]) -> Tensor
10401032
:rtype: Tensor
10411033
"""
10421034

1043-
def _basis_gate(k_tensor: Any) -> Gate:
1044-
vec = backend.one_hot(backend.cast(k_tensor, "int32"), self._d)
1045-
vec = backend.cast(vec, dtypestr)
1046-
return Gate(vec)
1047-
10481035
traceout = backend.cast(traceout, dtypestr)
10491036
nodes, front = self._copy()
10501037
L = self._nqubits
10511038
edges = []
10521039
for i in range(len(traceout)):
10531040
if i not in left:
1054-
n = _basis_gate(traceout[i])
1041+
n = Gate(onehot_d_tensor(traceout[i], d=self._d))
10551042
nodes.append(n)
10561043
front[i] ^ n[0]
10571044
else:
@@ -1060,7 +1047,7 @@ def _basis_gate(k_tensor: Any) -> Gate:
10601047
if self.is_dm:
10611048
for i in range(len(traceout)):
10621049
if i not in left:
1063-
n = _basis_gate(traceout[i])
1050+
n = Gate(onehot_d_tensor(traceout[i], d=self._d))
10641051
nodes.append(n)
10651052
front[i + L] ^ n[0]
10661053
else:

tensorcircuit/quantum.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,44 @@ def get_all_nodes(edges: Iterable[Edge]) -> List[Node]:
5757
return nodes
5858

5959

60+
def onehot_d_tensor(_k: Union[int, Tensor], d: int = 2) -> Tensor:
61+
"""
62+
Construct a one-hot vector (or matrix) of local dimension ``d``.
63+
64+
:param _k: index or indices to set as 1. Can be an int or a backend Tensor.
65+
:type _k: int or Tensor
66+
:param d: local dimension (number of categories), defaults to 2
67+
:type d: int, optional
68+
:return: one-hot encoded vector (shape [d]) or matrix (shape [len(_k), d])
69+
:rtype: Tensor
70+
"""
71+
if isinstance(_k, int):
72+
vec = backend.one_hot(_k, d)
73+
else:
74+
vec = backend.one_hot(backend.cast(_k, "int32"), d)
75+
return backend.cast(vec, dtypestr)
76+
77+
6078
def _decode_basis_label(label: str, n: int, dim: int) -> List[int]:
79+
"""
80+
Decode a string basis label into a list of integer digits.
81+
82+
The label is interpreted in base-``dim`` using characters ``0–9A–Z``.
83+
Only dimensions up to 36 are supported.
84+
85+
:param label: basis label string, e.g. "010" or "A9F"
86+
:type label: str
87+
:param n: number of sites (expected length of the label)
88+
:type n: int
89+
:param dim: local dimension (2 <= dim <= 36)
90+
:type dim: int
91+
:return: list of integer digits of length ``n``, each in ``[0, dim-1]``
92+
:rtype: List[int]
93+
94+
:raises NotImplementedError: if ``dim > 36``
95+
:raises ValueError: if the label length mismatches ``n``,
96+
or contains invalid/out-of-range characters
97+
"""
6198
if dim > 36:
6299
raise NotImplementedError(
63100
f"String basis label supports d<=36 (0–9A–Z). Got dim={dim}. "

0 commit comments

Comments
 (0)