Source code for oqpy.base

############################################################################
#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License").
#  You may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
############################################################################
"""Base classes and conversion methods for OQpy.

This module establishes how expressions are represented in oqpy and how
they are converted to AST nodes.
"""

from __future__ import annotations

import math
import uuid
from abc import ABC, abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    Hashable,
    Iterable,
    Optional,
    Protocol,
    Sequence,
    Union,
    cast,
    runtime_checkable,
)

import numpy as np
from openpulse import ast

from oqpy import classical_types

if TYPE_CHECKING:
    from oqpy import Program


class OQPyExpression:
    """Root of the OQPy expression hierarchy.

    A subclass is expected to override ``to_ast`` and to provide the ``type``
    attribute.

    Arithmetic, bitwise and comparison operators are overloaded so that
    applying them yields a new expression object. In particular ``==`` builds
    an expression instead of returning a python boolean, so it cannot be used
    to test whether two expressions are equal.
    """

    type: Optional[ast.ClassicalType]

    def to_ast(self, program: Program) -> ast.Expression:
        """Return the ast node for this expression; subclasses override this."""
        raise NotImplementedError  # pragma: no cover

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_binary(
        op_name: str,
        first: AstConvertible,
        second: AstConvertible,
        result_type: ast.ClassicalType | None = None,
    ) -> OQPyBinaryExpression:
        """Build the binary expression ``first <op_name> second``."""
        operator = ast.BinaryOperator[op_name]
        return OQPyBinaryExpression(operator, first, second, result_type)

    @staticmethod
    def _to_unary(op_name: str, exp: AstConvertible) -> OQPyUnaryExpression:
        """Build the unary expression ``<op_name> exp``."""
        operator = ast.UnaryOperator[op_name]
        return OQPyUnaryExpression(operator, exp)

    # ------------------------------------------------------------------
    # Unary operators
    # ------------------------------------------------------------------
    def __pos__(self) -> OQPyExpression:
        return self

    def __neg__(self) -> OQPyUnaryExpression:
        return self._to_unary("-", self)

    def __invert__(self) -> OQPyUnaryExpression:
        return self._to_unary("~", self)

    # ------------------------------------------------------------------
    # Additive and modulo operators
    # ------------------------------------------------------------------
    def __add__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("+", self, other)

    def __radd__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("+", other, self)

    def __sub__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("-", self, other)

    def __rsub__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("-", other, self)

    def __mod__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("%", self, other)

    def __rmod__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("%", other, self)

    # ------------------------------------------------------------------
    # Products and quotients: the result type is computed from both operands
    # ------------------------------------------------------------------
    def __mul__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("*", self, other, compute_product_types(self, other))

    def __rmul__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("*", other, self, compute_product_types(other, self))

    def __truediv__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("/", self, other, compute_quotient_types(self, other))

    def __rtruediv__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("/", other, self, compute_quotient_types(other, self))

    def __pow__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("**", self, other)

    def __rpow__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("**", other, self)

    # ------------------------------------------------------------------
    # Shifts and bitwise operators
    # ------------------------------------------------------------------
    def __lshift__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("<<", self, other)

    def __rlshift__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("<<", other, self)

    def __rshift__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary(">>", self, other)

    def __rrshift__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary(">>", other, self)

    def __and__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("&", self, other)

    def __rand__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("&", other, self)

    def __or__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("|", self, other)

    def __ror__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("|", other, self)

    def __xor__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("^", self, other)

    def __rxor__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("^", other, self)

    # ------------------------------------------------------------------
    # Comparisons: each one builds an expression, never a python bool
    # ------------------------------------------------------------------
    def __eq__(self, other: AstConvertible) -> OQPyBinaryExpression:  # type: ignore[override]
        return self._to_binary("==", self, other)

    def __ne__(self, other: AstConvertible) -> OQPyBinaryExpression:  # type: ignore[override]
        return self._to_binary("!=", self, other)

    def __gt__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary(">", self, other)

    def __lt__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("<", self, other)

    def __ge__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary(">=", self, other)

    def __le__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("<=", self, other)

    def __bool__(self) -> bool:
        # Truth-testing an expression is always a mistake, so fail loudly
        # rather than silently treating the expression object as truthy.
        raise RuntimeError(
            "OQPy expressions cannot be converted to bool. This can occur if you try to check "
            "the equality of expressions using == instead of expr_matches."
        )
def _get_type(val: AstConvertible) -> Optional[ast.ClassicalType]:
    """Return the classical type used for ``val`` in product/quotient typing.

    Args:
        val: An OQPy expression or a python/numpy numeric scalar.

    Returns:
        ``val.type`` for an ``OQPyExpression`` (which may be ``None``);
        otherwise a fresh ``IntType``, ``FloatType`` or ``ComplexType``
        matching the scalar's kind.

    Raises:
        ValueError: If ``val`` is neither an expression nor a supported scalar.
    """
    if isinstance(val, OQPyExpression):
        return val.type
    # NumPy scalar kinds are accepted alongside the builtins so that the same
    # values ``to_ast`` can turn into literals are also usable in products and
    # quotients. ``bool`` is a subclass of ``int`` and therefore maps to IntType.
    if isinstance(val, (int, np.integer)):
        return ast.IntType()
    if isinstance(val, (float, np.floating)):
        return ast.FloatType()
    if isinstance(val, (complex, np.complexfloating)):
        return ast.ComplexType(ast.FloatType())
    raise ValueError(f"Cannot multiply/divide oqpy expression with {type(val)}")
def compute_product_types(left: AstConvertible, right: AstConvertible) -> ast.ClassicalType:
    """Determine the type of ``left * right``.

    The operand types are obtained with ``_get_type`` and looked up, as a
    (left class, right class) pair, in a table. A table entry is either the
    type object to return (taken from one of the operands) or a ``TypeError``
    instance to raise for a forbidden combination.

    Raises:
        TypeError: If the combination is forbidden or absent from the table.
    """
    lhs_type = _get_type(left)
    rhs_type = _get_type(right)

    # Short aliases keep the table readable.
    float_t, int_t, uint_t = ast.FloatType, ast.IntType, ast.UintType
    duration_t, angle_t, complex_t = ast.DurationType, ast.AngleType, ast.ComplexType

    outcomes = {
        # float * ...
        (float_t, float_t): lhs_type,
        (float_t, int_t): lhs_type,
        (float_t, uint_t): lhs_type,
        (float_t, duration_t): rhs_type,
        (float_t, angle_t): rhs_type,
        (float_t, complex_t): rhs_type,
        # int * ...
        (int_t, float_t): rhs_type,
        (int_t, int_t): lhs_type,
        (int_t, uint_t): lhs_type,
        (int_t, duration_t): rhs_type,
        (int_t, angle_t): rhs_type,
        (int_t, complex_t): rhs_type,
        # uint * ...
        (uint_t, float_t): rhs_type,
        (uint_t, int_t): rhs_type,
        (uint_t, uint_t): lhs_type,
        (uint_t, duration_t): rhs_type,
        (uint_t, angle_t): rhs_type,
        (uint_t, complex_t): rhs_type,
        # duration * ...
        (duration_t, float_t): lhs_type,
        (duration_t, int_t): lhs_type,
        (duration_t, uint_t): lhs_type,
        (duration_t, duration_t): TypeError(
            "Cannot multiply two durations. You may need to re-group computations to eliminate this."
        ),
        (duration_t, angle_t): TypeError("Cannot multiply duration and angle"),
        (duration_t, complex_t): TypeError("Cannot multiply duration and complex"),
        # angle * ...
        (angle_t, float_t): lhs_type,
        (angle_t, int_t): lhs_type,
        (angle_t, uint_t): lhs_type,
        (angle_t, duration_t): TypeError("Cannot multiply angle and duration"),
        (angle_t, angle_t): TypeError("Cannot multiply two angles"),
        (angle_t, complex_t): TypeError("Cannot multiply angle and complex"),
        # complex * ...
        (complex_t, float_t): lhs_type,
        (complex_t, int_t): lhs_type,
        (complex_t, uint_t): lhs_type,
        (complex_t, duration_t): TypeError("Cannot multiply complex and duration"),
        (complex_t, angle_t): TypeError("Cannot multiply complex and angle"),
        (complex_t, complex_t): lhs_type,
    }

    lookup_key = (type(lhs_type), type(rhs_type))
    try:
        outcome = outcomes[lookup_key]
    except KeyError as e:
        raise TypeError(f"Could not identify types for product {left} and {right}") from e

    if isinstance(outcome, Exception):
        raise outcome
    return outcome
def compute_quotient_types(left: AstConvertible, right: AstConvertible) -> ast.ClassicalType:
    """Find the result type for a quotient of two terms.

    The operand types are obtained with ``_get_type`` and the pair of their
    classes is looked up in a table. An entry is either the type to return or
    a ``TypeError`` instance describing why the division is not allowed.

    Args:
        left: The dividend.
        right: The divisor.

    Returns:
        The type of ``left / right``. Quotients of two integer-like values,
        two durations or two angles yield a fresh ``FloatType``; otherwise the
        type object of one of the operands is returned.

    Raises:
        TypeError: If the combination is forbidden or absent from the table.
    """
    left_type = _get_type(left)
    right_type = _get_type(right)
    float_type = ast.FloatType()

    types_map = {
        (ast.FloatType, ast.FloatType): left_type,
        (ast.FloatType, ast.IntType): left_type,
        (ast.FloatType, ast.UintType): left_type,
        (ast.FloatType, ast.DurationType): TypeError("Cannot divide float by duration"),
        (ast.FloatType, ast.AngleType): TypeError("Cannot divide float by angle"),
        (ast.FloatType, ast.ComplexType): right_type,
        (ast.IntType, ast.FloatType): right_type,
        (ast.IntType, ast.IntType): float_type,
        (ast.IntType, ast.UintType): float_type,
        (ast.IntType, ast.DurationType): TypeError("Cannot divide int by duration"),
        (ast.IntType, ast.AngleType): TypeError("Cannot divide int by angle"),
        (ast.IntType, ast.ComplexType): right_type,
        (ast.UintType, ast.FloatType): right_type,
        (ast.UintType, ast.IntType): float_type,
        (ast.UintType, ast.UintType): float_type,
        (ast.UintType, ast.DurationType): TypeError("Cannot divide uint by duration"),
        (ast.UintType, ast.AngleType): TypeError("Cannot divide uint by angle"),
        (ast.UintType, ast.ComplexType): right_type,
        (ast.DurationType, ast.FloatType): left_type,
        (ast.DurationType, ast.IntType): left_type,
        (ast.DurationType, ast.UintType): left_type,
        (ast.DurationType, ast.DurationType): float_type,
        (ast.DurationType, ast.AngleType): TypeError("Cannot divide duration by angle"),
        (ast.DurationType, ast.ComplexType): TypeError("Cannot divide duration by complex"),
        (ast.AngleType, ast.FloatType): left_type,
        (ast.AngleType, ast.IntType): left_type,
        (ast.AngleType, ast.UintType): left_type,
        (ast.AngleType, ast.DurationType): TypeError("Cannot divide angle by duration"),
        (ast.AngleType, ast.AngleType): float_type,
        (ast.AngleType, ast.ComplexType): TypeError("Cannot divide angle by complex"),
        (ast.ComplexType, ast.FloatType): left_type,
        (ast.ComplexType, ast.IntType): left_type,
        (ast.ComplexType, ast.UintType): left_type,
        (ast.ComplexType, ast.DurationType): TypeError("Cannot divide complex by duration"),
        (ast.ComplexType, ast.AngleType): TypeError("Cannot divide complex by angle"),
        (ast.ComplexType, ast.ComplexType): left_type,
    }

    try:
        result_type = types_map[type(left_type), type(right_type)]
    except KeyError as e:
        raise TypeError(f"Could not identify types for quotient {left} and {right}") from e
    if isinstance(result_type, Exception):
        raise result_type
    return result_type
def logical_and(first: AstConvertible, second: AstConvertible) -> OQPyBinaryExpression:
    """Build the expression ``first && second``."""
    and_operator = ast.BinaryOperator["&&"]
    return OQPyBinaryExpression(and_operator, first, second)
def logical_or(first: AstConvertible, second: AstConvertible) -> OQPyBinaryExpression:
    """Build the expression ``first || second``."""
    or_operator = ast.BinaryOperator["||"]
    return OQPyBinaryExpression(or_operator, first, second)
def expr_matches(a: Any, b: Any) -> bool:
    """Check equality of the given objects.

    This bypasses calling ``__eq__`` on expr objects: objects that have a
    ``__dict__`` are compared attribute by attribute, and lists, tuples,
    numpy arrays and dicts are compared element by element, so the overloaded
    ``==`` of an expression is never invoked on the way.

    Args:
        a: First object.
        b: Second object.

    Returns:
        ``False`` when the two objects are of different types or differ
        structurally; otherwise the outcome of the recursive comparison, which
        bottoms out in ``a == b`` for objects without a ``__dict__``.
    """
    if type(a) is not type(b):
        return False

    if isinstance(a, np.ndarray):
        if a.shape != b.shape:
            return False
        if a.ndim == 0:
            # ``len`` and iteration are not defined for 0-d arrays, so compare
            # the single element they hold instead.
            return expr_matches(a[()], b[()])

    # Tuples are walked like lists: the default tuple ``==`` would call
    # ``__eq__`` on any expression stored inside the tuple.
    if isinstance(a, (list, tuple, np.ndarray)):
        if len(a) != len(b):
            return False
        return all(expr_matches(ai, bi) for ai, bi in zip(a, b))

    if isinstance(a, dict):
        if a.keys() != b.keys():
            return False
        return all(expr_matches(va, b[k]) for k, va in a.items())

    if hasattr(a, "__dict__"):
        return expr_matches(a.__dict__, b.__dict__)
    return a == b
@runtime_checkable
class ExpressionConvertible(Protocol):
    """Protocol for objects that can stand in for an expression.

    An implementer provides ``_to_oqpy_expression``, which returns an object
    exposing ``to_ast``.
    """

    def _to_oqpy_expression(self) -> HasToAst:
        ...  # pragma: no cover
@runtime_checkable
class CachedExpressionConvertible(Protocol):
    """Protocol for objects that can stand in for an expression, with caching.

    Unlike `ExpressionConvertible`, an implementer promises that whatever
    `_to_cached_oqpy_expression` returns stays constant across the lifetime of
    the OQPy Program. OQPy makes an effort to minimize the number of calls to
    the AST constructor, but no guarantees are made about this.
    """

    # Key identifying the cached conversion of this object.
    _oqpy_cache_key: Hashable

    def _to_cached_oqpy_expression(self) -> HasToAst:
        ...  # pragma: no cover
class OQPyUnaryExpression(OQPyExpression):
    """An operator applied to a single operand expression."""

    def __init__(self, op: ast.UnaryOperator, exp: AstConvertible):
        super().__init__()
        self.op = op
        self.exp = exp
        # The result type is copied from the operand, so the operand has to be
        # an expression that carries one.
        if not isinstance(exp, OQPyExpression):
            raise TypeError("exp is not an expression")
        self.type = exp.type

    def to_ast(self, program: Program) -> ast.UnaryExpression:
        """Return the ast node applying the operator to the converted operand."""
        operand_node = to_ast(program, self.exp)
        return ast.UnaryExpression(self.op, operand_node)
class OQPyBinaryExpression(OQPyExpression):
    """Two operands combined by a binary operator."""

    def __init__(
        self,
        op: ast.BinaryOperator,
        lhs: AstConvertible,
        rhs: AstConvertible,
        ast_type: ast.ClassicalType | None = None,
    ):
        super().__init__()
        self.op = op
        self.lhs = lhs
        self.rhs = rhs

        # TODO (#9): More robust type checking which considers both arguments
        # types, as well as the operator.
        if ast_type is None:
            # Without an explicit type, borrow the type of the first operand
            # (left before right) that is itself an expression.
            for operand in (lhs, rhs):
                if isinstance(operand, OQPyExpression):
                    ast_type = operand.type
                    break
            else:
                raise TypeError("Neither lhs nor rhs is an expression?")
        self.type = ast_type

        # Adding floats to durations is not allowed. So we promote types as necessary.
        if isinstance(self.type, ast.DurationType):
            if self.op in (ast.BinaryOperator["+"], ast.BinaryOperator["-"]):
                # Late import to avoid circular imports.
                from oqpy.timing import convert_float_to_duration

                self.lhs = convert_float_to_duration(self.lhs)
                self.rhs = convert_float_to_duration(self.rhs)

    def to_ast(self, program: Program) -> ast.BinaryExpression:
        """Return the ast node joining both converted operands with the operator."""
        left_node = to_ast(program, self.lhs)
        right_node = to_ast(program, self.rhs)
        return ast.BinaryExpression(self.op, left_node, right_node)
class Var(ABC):
    """Common abstract parent of classical and quantum variables."""

    def __init__(self, name: str, needs_declaration: bool = True):
        self.name = name
        self._needs_declaration = needs_declaration

    def _var_matches(self, other: Any) -> bool:
        """Tell whether ``other`` represents the same variable as this object.

        Expressions overload ``==``, so for them the comparison is delegated to
        ``expr_matches``; any other variable is compared with plain ``==``.
        """
        if not isinstance(self, OQPyExpression):
            return self == other
        return expr_matches(self, other)

    @abstractmethod
    def to_ast(self, program: Program) -> ast.Expression:
        """Return the ast node for this variable."""
        ...

    @abstractmethod
    def make_declaration_statement(self, program: Program) -> ast.Statement:
        """Return the ast statement declaring this variable."""
        ...
@runtime_checkable
class HasToAst(Protocol):
    """Structural type for anything exposing a ``to_ast(program)`` method."""

    def to_ast(self, program: Program) -> ast.Expression:
        """Return the ast node for this object."""
        ...  # pragma: no cover
# Type alias used throughout this module to annotate values that may be handed
# to ``to_ast``: objects with a ``to_ast`` method, python scalars, iterables,
# objects implementing one of the conversion protocols, and ready-made ast
# expressions.
AstConvertible = Union[ HasToAst, bool, int, float, complex, Iterable, ExpressionConvertible, CachedExpressionConvertible, ast.Expression, ]
def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
    """Convert an object to an AST node.

    The checks are applied in this order:

    1. objects providing ``_to_oqpy_expression``;
    2. objects providing ``_to_cached_oqpy_expression`` (the produced node is
       stored in ``program.expr_cache`` under the object's ``_oqpy_cache_key``);
    3. complex, boolean, integer and float scalars (python or numpy);
    4. slices, which become a ``RangeDefinition``;
    5. other iterables, which become an ``ArrayLiteral``;
    6. ``ast.Expression`` instances, returned unchanged;
    7. objects providing ``to_ast``.

    Args:
        program: The program the node is built for; it supplies ``expr_cache``
            and the ``simplify_constants`` setting.
        item: The object to convert.

    Returns:
        The ast expression representing ``item``.

    Raises:
        TypeError: If ``item`` matches none of the cases above. Strings are
            rejected this way as well.
    """
    if hasattr(item, "_to_oqpy_expression"):
        item = cast(ExpressionConvertible, item)
        return item._to_oqpy_expression().to_ast(program)

    if hasattr(item, "_to_cached_oqpy_expression"):
        item = cast(CachedExpressionConvertible, item)
        # Assign a key on first use so later conversions can hit the cache.
        if item._oqpy_cache_key is None:
            item._oqpy_cache_key = uuid.uuid1()
        if item._oqpy_cache_key not in program.expr_cache:
            program.expr_cache[item._oqpy_cache_key] = item._to_cached_oqpy_expression().to_ast(
                program
            )
        return program.expr_cache[item._oqpy_cache_key]

    if isinstance(item, (complex, np.complexfloating)):
        # Purely real values go through the float branch below.
        if item.imag == 0:
            return to_ast(program, item.real)
        # Negative imaginary parts are written as a minus sign applied to a
        # positive imaginary literal.
        if item.real == 0:
            if item.imag < 0:
                return ast.UnaryExpression(ast.UnaryOperator["-"], ast.ImaginaryLiteral(-item.imag))
            return ast.ImaginaryLiteral(item.imag)
        if item.imag < 0:
            return ast.BinaryExpression(
                ast.BinaryOperator["-"],
                ast.FloatLiteral(item.real),
                ast.ImaginaryLiteral(-item.imag),
            )
        return ast.BinaryExpression(
            ast.BinaryOperator["+"], ast.FloatLiteral(item.real), ast.ImaginaryLiteral(item.imag)
        )

    # Booleans must be tested before integers because ``bool`` is an ``int``.
    if isinstance(item, (bool, np.bool_)):
        # Normalise numpy booleans to a plain ``bool``, mirroring the ``int``
        # normalisation in the integer branch.
        return ast.BooleanLiteral(bool(item))

    if isinstance(item, (int, np.integer)):
        item = int(item)
        if item < 0:
            return ast.UnaryExpression(ast.UnaryOperator["-"], ast.IntegerLiteral(-item))
        return ast.IntegerLiteral(item)

    if isinstance(item, (float, np.floating)):
        if item < 0:
            if program.simplify_constants:
                neg_ast_term = detect_and_convert_constants(-item, program)
            else:
                neg_ast_term = ast.FloatLiteral(-item)
            return ast.UnaryExpression(ast.UnaryOperator["-"], neg_ast_term)
        if program.simplify_constants:
            return detect_and_convert_constants(item, program)
        return ast.FloatLiteral(item)

    if isinstance(item, slice):
        # NOTE(review): ``stop - 1`` presumably maps python's exclusive stop
        # onto an inclusive range end; this looks wrong for negative steps,
        # confirm against callers.
        return ast.RangeDefinition(
            to_ast(program, item.start) if item.start is not None else None,
            to_ast(program, item.stop - 1) if item.stop is not None else None,
            to_ast(program, item.step) if item.step is not None else None,
        )

    # A string is iterable and every character is again a string, so treating
    # it as an array would recurse without end. Let it fall through to the
    # TypeError at the bottom instead.
    if isinstance(item, Iterable) and not isinstance(item, str):
        return ast.ArrayLiteral([to_ast(program, i) for i in item])

    if isinstance(item, ast.Expression):
        return item

    if hasattr(item, "to_ast"):  # Using isinstance(HasToAst) slowish
        return item.to_ast(program)

    raise TypeError(f"Cannot convert {item} of type {type(item)} to ast")
def optional_ast(program: Program, item: AstConvertible | None) -> ast.Expression | None:
    """Return ``None`` for a ``None`` item, otherwise its ast conversion."""
    return None if item is None else to_ast(program, item)
def map_to_ast(program: Program, items: Iterable[AstConvertible]) -> list[ast.Expression]:
    """Return a list holding the ast conversion of every element of ``items``."""
    nodes = [to_ast(program, element) for element in items]
    return nodes
def make_annotations(vals: Sequence[str | tuple[str, str]]) -> list[ast.Annotation]:
    """Build one ``ast.Annotation`` per entry of ``vals``.

    A plain string is passed as the only argument of the annotation; a pair is
    unpacked into a keyword and a command, which are passed in that order.
    """

    def _build(entry: str | tuple[str, str]) -> ast.Annotation:
        if isinstance(entry, str):
            return ast.Annotation(entry)
        keyword, command = entry
        return ast.Annotation(keyword, command)

    return [_build(entry) for entry in vals]
def detect_and_convert_constants(val: float | np.floating[Any], program: Program) -> ast.Expression:
    """Construct a float ast expression which is either a literal or an expression using constants.

    Values in the closed range [0.5, 100] that equal an integer multiple of
    pi/4 (to a relative tolerance of 1e-12) are expressed in terms of
    ``classical_types.pi``; every other value becomes a plain ``FloatLiteral``.

    Args:
        val: The value to convert. The only caller in this module passes
            non-negative values.
        program: Program used to convert the resulting expression to ast.

    Returns:
        A ``FloatLiteral`` or the ast of an expression built from ``pi``.
    """
    # The negated chained comparison also sends NaN to the literal branch:
    # NaN fails every comparison, and ``round`` below would raise on it.
    if not 0.5 <= val <= 100:
        return ast.FloatLiteral(val)

    # Number of quarter-pi units contained in ``val``.
    x = val / (math.pi / 4.0)
    rx = round(x)
    if not math.isclose(x, rx, rel_tol=1e-12):
        return ast.FloatLiteral(val)

    # Pick the shortest spelling: pi, pi / 2, pi / 4, or a multiple of the
    # largest of those that divides the value.
    term: OQPyExpression
    if rx == 4:
        term = classical_types.pi
    elif rx == 2:
        term = classical_types.pi / 2
    elif rx == 1:
        term = classical_types.pi / 4
    elif rx % 4 == 0:
        term = (rx // 4) * classical_types.pi
    elif rx % 2 == 0:
        term = (rx // 2) * classical_types.pi / 2
    else:
        term = rx * classical_types.pi / 4
    return term.to_ast(program)