Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions qiskit/circuit/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def __init__(
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: symbol}
self._name_map = None
self._qpy_replay = []
self._standalone_param = True

def assign(self, parameter, value):
if parameter != self:
Expand Down Expand Up @@ -172,3 +174,5 @@ def __setstate__(self, state):
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: self._symbol_expr}
self._name_map = None
self._qpy_replay = []
self._standalone_param = True
201 changes: 167 additions & 34 deletions qiskit/circuit/parameterexpression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
"""

from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, Union

import numbers
Expand All @@ -30,12 +33,86 @@
ParameterValueType = Union["ParameterExpression", float]


class _OPCode(IntEnum):
ADD = 0
SUB = 1
MUL = 2
DIV = 3
POW = 4
SIN = 5
COS = 6
TAN = 7
ASIN = 8
ACOS = 9
EXP = 10
LOG = 11
SIGN = 12
GRAD = 13
CONJ = 14
SUBSTITUTE = 15
ABS = 16
ATAN = 17
RSUB = 18
RDIV = 19
RPOW = 20


_OP_CODE_MAP = (
"__add__",
"__sub__",
"__mul__",
"__truediv__",
"__pow__",
"sin",
"cos",
"tan",
"arcsin",
"arccos",
"exp",
"log",
"sign",
"gradient",
"conjugate",
"subs",
"abs",
"arctan",
"__rsub__",
"__rtruediv__",
"__rpow__",
)


def op_code_to_method(op_code: _OPCode):
"""Return the method name for a given op_code."""
return _OP_CODE_MAP[op_code]


@dataclass
class _INSTRUCTION:
op: _OPCode
lhs: ParameterValueType | None
rhs: ParameterValueType | None = None


@dataclass
class _SUBS:
binds: dict
op: _OPCode = _OPCode.SUBSTITUTE


class ParameterExpression:
"""ParameterExpression class to enable creating expressions of Parameters."""

__slots__ = ["_parameter_symbols", "_parameter_keys", "_symbol_expr", "_name_map"]
__slots__ = [
"_parameter_symbols",
"_parameter_keys",
"_symbol_expr",
"_name_map",
"_qpy_replay",
"_standalone_param",
]

def __init__(self, symbol_map: dict, expr):
def __init__(self, symbol_map: dict, expr, *, _qpy_replay=None):
"""Create a new :class:`ParameterExpression`.

Not intended to be called directly, but to be instantiated via operations
Expand All @@ -54,6 +131,11 @@ def __init__(self, symbol_map: dict, expr):
self._parameter_keys = frozenset(p._hash_key() for p in self._parameter_symbols)
self._symbol_expr = expr
self._name_map: dict | None = None
self._standalone_param = False
if _qpy_replay is not None:
self._qpy_replay = _qpy_replay
else:
self._qpy_replay = []

@property
def parameters(self) -> set:
Expand All @@ -69,8 +151,14 @@ def _names(self) -> dict:

def conjugate(self) -> "ParameterExpression":
"""Return the conjugate."""
if self._standalone_param:
new_op = _INSTRUCTION(_OPCode.CONJ, self)
else:
new_op = _INSTRUCTION(_OPCode.CONJ, None)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
conjugated = ParameterExpression(
self._parameter_symbols, symengine.conjugate(self._symbol_expr)
self._parameter_symbols, symengine.conjugate(self._symbol_expr), _qpy_replay=new_replay
)
return conjugated

Expand Down Expand Up @@ -117,6 +205,7 @@ def bind(
self._raise_if_passed_unknown_parameters(parameter_values.keys())
self._raise_if_passed_nan(parameter_values)

new_op = _SUBS(parameter_values)
symbol_values = {}
for parameter, value in parameter_values.items():
if (param_expr := self._parameter_symbols.get(parameter)) is not None:
Expand All @@ -143,7 +232,12 @@ def bind(
f"(Expression: {self}, Bindings: {parameter_values})."
)

return ParameterExpression(free_parameter_symbols, bound_symbol_expr)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

return ParameterExpression(
free_parameter_symbols, bound_symbol_expr, _qpy_replay=new_replay
)

def subs(
self, parameter_map: dict, allow_unknown_parameters: bool = False
Expand Down Expand Up @@ -175,6 +269,7 @@ def subs(
for p in replacement_expr.parameters
}
self._raise_if_parameter_names_conflict(inbound_names, parameter_map.keys())
new_op = _SUBS(parameter_map)

# Include existing parameters in self not set to be replaced.
new_parameter_symbols = {
Expand All @@ -192,8 +287,12 @@ def subs(
new_parameter_symbols[p] = symbol_type(p.name)

substituted_symbol_expr = self._symbol_expr.subs(symbol_map)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

return ParameterExpression(new_parameter_symbols, substituted_symbol_expr)
return ParameterExpression(
new_parameter_symbols, substituted_symbol_expr, _qpy_replay=new_replay
)

def _raise_if_passed_unknown_parameters(self, parameters):
unknown_parameters = parameters - self.parameters
Expand Down Expand Up @@ -231,7 +330,11 @@ def _raise_if_parameter_names_conflict(self, inbound_parameters, outbound_parame
)

def _apply_operation(
self, operation: Callable, other: ParameterValueType, reflected: bool = False
self,
operation: Callable,
other: ParameterValueType,
reflected: bool = False,
op_code: _OPCode = None,
) -> "ParameterExpression":
"""Base method implementing math operations between Parameters and
either a constant or a second ParameterExpression.
Expand All @@ -253,7 +356,6 @@ def _apply_operation(
A new expression describing the result of the operation.
"""
self_expr = self._symbol_expr

if isinstance(other, ParameterExpression):
self._raise_if_parameter_names_conflict(other._names)
parameter_symbols = {**self._parameter_symbols, **other._parameter_symbols}
Expand All @@ -266,10 +368,26 @@ def _apply_operation(

if reflected:
expr = operation(other_expr, self_expr)
if op_code in {_OPCode.RSUB, _OPCode.RDIV, _OPCode.RPOW}:
if self._standalone_param:
new_op = _INSTRUCTION(op_code, self, other)
else:
new_op = _INSTRUCTION(op_code, None, other)
else:
if self._standalone_param:
new_op = _INSTRUCTION(op_code, other, self)
else:
new_op = _INSTRUCTION(op_code, other, None)
else:
expr = operation(self_expr, other_expr)

out_expr = ParameterExpression(parameter_symbols, expr)
if self._standalone_param:
new_op = _INSTRUCTION(op_code, self, other)
else:
new_op = _INSTRUCTION(op_code, None, other)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

out_expr = ParameterExpression(parameter_symbols, expr, _qpy_replay=new_replay)
out_expr._name_map = self._names.copy()
if isinstance(other, ParameterExpression):
out_expr._names.update(other._names.copy())
Expand All @@ -291,6 +409,13 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
# If it is not contained then return 0
return 0.0

if self._standalone_param:
new_op = _INSTRUCTION(_OPCode.GRAD, self, param)
else:
new_op = _INSTRUCTION(_OPCode.GRAD, None, param)
qpy_replay = self._qpy_replay.copy()
qpy_replay.append(new_op)

# Compute the gradient of the parameter expression w.r.t. param
key = self._parameter_symbols[param]
expr_grad = symengine.Derivative(self._symbol_expr, key)
Expand All @@ -304,7 +429,7 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
parameter_symbols[parameter] = symbol
# If the gradient corresponds to a parameter expression then return the new expression.
if len(parameter_symbols) > 0:
return ParameterExpression(parameter_symbols, expr=expr_grad)
return ParameterExpression(parameter_symbols, expr=expr_grad, _qpy_replay=qpy_replay)
# If no free symbols left, return a complex or float gradient
expr_grad_cplx = complex(expr_grad)
if expr_grad_cplx.imag != 0:
Expand All @@ -313,81 +438,89 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
return float(expr_grad)

def __add__(self, other):
return self._apply_operation(operator.add, other)
return self._apply_operation(operator.add, other, op_code=_OPCode.ADD)

def __radd__(self, other):
return self._apply_operation(operator.add, other, reflected=True)
return self._apply_operation(operator.add, other, reflected=True, op_code=_OPCode.ADD)

def __sub__(self, other):
return self._apply_operation(operator.sub, other)
return self._apply_operation(operator.sub, other, op_code=_OPCode.SUB)

def __rsub__(self, other):
return self._apply_operation(operator.sub, other, reflected=True)
return self._apply_operation(operator.sub, other, reflected=True, op_code=_OPCode.RSUB)

def __mul__(self, other):
return self._apply_operation(operator.mul, other)
return self._apply_operation(operator.mul, other, op_code=_OPCode.MUL)

def __pos__(self):
return self._apply_operation(operator.mul, 1)
return self._apply_operation(operator.mul, 1, op_code=_OPCode.MUL)

def __neg__(self):
return self._apply_operation(operator.mul, -1)
return self._apply_operation(operator.mul, -1, op_code=_OPCode.MUL)

def __rmul__(self, other):
return self._apply_operation(operator.mul, other, reflected=True)
return self._apply_operation(operator.mul, other, reflected=True, op_code=_OPCode.MUL)

def __truediv__(self, other):
if other == 0:
raise ZeroDivisionError("Division of a ParameterExpression by zero.")
return self._apply_operation(operator.truediv, other)
return self._apply_operation(operator.truediv, other, op_code=_OPCode.DIV)

def __rtruediv__(self, other):
return self._apply_operation(operator.truediv, other, reflected=True)
return self._apply_operation(operator.truediv, other, reflected=True, op_code=_OPCode.RDIV)

def __pow__(self, other):
return self._apply_operation(pow, other)
return self._apply_operation(pow, other, op_code=_OPCode.POW)

def __rpow__(self, other):
return self._apply_operation(pow, other, reflected=True)
return self._apply_operation(pow, other, reflected=True, op_code=_OPCode.RPOW)

def _call(self, ufunc):
return ParameterExpression(self._parameter_symbols, ufunc(self._symbol_expr))
def _call(self, ufunc, op_code):
if self._standalone_param:
new_op = _INSTRUCTION(op_code, self)
else:
new_op = _INSTRUCTION(op_code, None)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
return ParameterExpression(
self._parameter_symbols, ufunc(self._symbol_expr), _qpy_replay=new_replay
)

def sin(self):
"""Sine of a ParameterExpression"""
return self._call(symengine.sin)
return self._call(symengine.sin, op_code=_OPCode.SIN)

def cos(self):
"""Cosine of a ParameterExpression"""
return self._call(symengine.cos)
return self._call(symengine.cos, op_code=_OPCode.COS)

def tan(self):
"""Tangent of a ParameterExpression"""
return self._call(symengine.tan)
return self._call(symengine.tan, op_code=_OPCode.TAN)

def arcsin(self):
"""Arcsin of a ParameterExpression"""
return self._call(symengine.asin)
return self._call(symengine.asin, op_code=_OPCode.ASIN)

def arccos(self):
"""Arccos of a ParameterExpression"""
return self._call(symengine.acos)
return self._call(symengine.acos, op_code=_OPCode.ACOS)

def arctan(self):
"""Arctan of a ParameterExpression"""
return self._call(symengine.atan)
return self._call(symengine.atan, op_code=_OPCode.ATAN)

def exp(self):
"""Exponential of a ParameterExpression"""
return self._call(symengine.exp)
return self._call(symengine.exp, op_code=_OPCode.EXP)

def log(self):
"""Logarithm of a ParameterExpression"""
return self._call(symengine.log)
return self._call(symengine.log, op_code=_OPCode.LOG)

def sign(self):
"""Sign of a ParameterExpression"""
return self._call(symengine.sign)
return self._call(symengine.sign, op_code=_OPCode.SIGN)

def __repr__(self):
return f"{self.__class__.__name__}({str(self)})"
Expand Down Expand Up @@ -455,7 +588,7 @@ def __deepcopy__(self, memo=None):

def __abs__(self):
"""Absolute of a ParameterExpression"""
return self._call(symengine.Abs)
return self._call(symengine.Abs, _OPCode.ABS)

def abs(self):
"""Absolute of a ParameterExpression"""
Expand Down
Loading