# Source code for oqpy.base

############################################################################
#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License").
#  You may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
############################################################################
"""Base classes and conversion methods for OQpy.

This module establishes how expressions are represented in oqpy and how
they are converted to AST nodes.
"""

from __future__ import annotations

import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Iterable, Union

import numpy as np
from openpulse import ast

if sys.version_info >= (3, 8):
    from typing import Protocol, runtime_checkable
else:
    from typing_extensions import Protocol, runtime_checkable

if TYPE_CHECKING:
    from oqpy import Program


class OQPyExpression:
    """Base class for OQPy expressions.

    Subclasses must implement the ``to_ast`` method and supply the ``type`` attribute.

    Expressions can be composed via overloaded arithmetic and boolean comparison
    operators to create new expressions. Note this means you cannot evaluate expression
    equality via ``==``, which produces a new expression instead of producing a python
    boolean; use ``expr_matches`` for structural comparison. Because ``__eq__`` is
    overridden without a matching ``__hash__``, instances of this class are unhashable.
    """

    type: ast.ClassicalType

    def to_ast(self, program: Program) -> ast.Expression:
        """Converts the oqpy expression into an ast node."""
        raise NotImplementedError  # pragma: no cover

    @staticmethod
    def _to_binary(
        op_name: str, first: AstConvertible, second: AstConvertible
    ) -> OQPyBinaryExpression:
        """Helper method to produce a binary expression.

        ``op_name`` is the operator symbol (e.g. ``"+"``) used to look up the
        corresponding ``ast.BinaryOperator`` member; ``first``/``second`` are the
        left/right operands in that order.
        """
        return OQPyBinaryExpression(ast.BinaryOperator[op_name], first, second)

    # Arithmetic operators. Each reflected (``__r*__``) variant swaps the operand
    # order so that e.g. ``1 - expr`` yields the expression ``1 - expr``.

    def __add__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("+", self, other)

    def __radd__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("+", other, self)

    # Subtraction, division and exponentiation were previously missing, so e.g.
    # ``expr - 1`` raised TypeError even though ``expr + 1`` worked.
    def __sub__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("-", self, other)

    def __rsub__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("-", other, self)

    def __mul__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("*", self, other)

    def __rmul__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("*", other, self)

    def __truediv__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("/", self, other)

    def __rtruediv__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("/", other, self)

    def __mod__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("%", self, other)

    def __rmod__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("%", other, self)

    def __pow__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("**", self, other)

    def __rpow__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("**", other, self)

    # Comparison operators build expressions rather than returning python booleans.

    def __eq__(self, other: AstConvertible) -> OQPyBinaryExpression:  # type: ignore[override]
        return self._to_binary("==", self, other)

    def __ne__(self, other: AstConvertible) -> OQPyBinaryExpression:  # type: ignore[override]
        return self._to_binary("!=", self, other)

    def __gt__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary(">", self, other)

    def __lt__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("<", self, other)

    def __ge__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary(">=", self, other)

    def __le__(self, other: AstConvertible) -> OQPyBinaryExpression:
        return self._to_binary("<=", self, other)

    def __bool__(self) -> bool:
        # Refuse implicit truthiness so that ``if a == b:`` on expressions fails loudly
        # instead of silently treating the built expression as True.
        raise RuntimeError(
            "OQPy expressions cannot be converted to bool. This can occur if you try to check "
            "the equality of expressions using == instead of expr_matches."
        )
def expr_matches(a: Any, b: Any) -> bool:
    """Check structural equality of the given objects.

    This bypasses calling ``__eq__`` on expr objects, whose ``==`` operator builds a new
    expression instead of returning a python boolean.

    Objects of different types never match. Numpy arrays are compared by shape and then
    element-wise. Lists and tuples are compared element-wise. Dicts are compared key by
    key. Any other object with a ``__dict__`` is compared by recursively matching its
    attribute dictionary. Remaining objects fall back to ``==``.
    """
    if type(a) is not type(b):
        return False
    if isinstance(a, np.ndarray):
        # Compare shapes first: this supports 0-d arrays (which have no ``len()``) and
        # rejects mismatched shapes before any element-wise recursion.
        if a.shape != b.shape:
            return False
        return all(expr_matches(ai, bi) for ai, bi in zip(a.flat, b.flat))
    if isinstance(a, (list, tuple)):
        # Tuples must be handled here: the builtin tuple ``==`` would call ``__eq__`` on
        # contained expressions and then ``bool()`` the result, which raises.
        if len(a) != len(b):
            return False
        return all(expr_matches(ai, bi) for ai, bi in zip(a, b))
    if isinstance(a, dict):
        if a.keys() != b.keys():
            return False
        return all(expr_matches(va, b[k]) for k, va in a.items())
    if hasattr(a, "__dict__"):
        return expr_matches(a.__dict__, b.__dict__)
    return a == b
@runtime_checkable
class ExpressionConvertible(Protocol):
    """Structural type for objects that can stand in for an OQPy expression.

    An object opts in simply by defining ``_to_oqpy_expression``, which returns an
    object exposing a ``to_ast`` method.
    """

    def _to_oqpy_expression(self) -> HasToAst:
        """Return an expression-like object that can be converted into an ast node."""
class OQPyBinaryExpression(OQPyExpression):
    """Expression formed by joining a left and a right operand with a binary operator.

    Either operand may be any ``AstConvertible`` value, but at least one must be an
    ``OQPyExpression``. The resulting ``type`` is copied from the first such operand,
    checking ``lhs`` before ``rhs``.
    """

    def __init__(self, op: ast.BinaryOperator, lhs: AstConvertible, rhs: AstConvertible):
        super().__init__()
        self.op = op
        self.lhs = lhs
        self.rhs = rhs
        # TODO (#50): More robust type checking which considers both arguments
        # types, as well as the operator.
        for operand in (lhs, rhs):
            if isinstance(operand, OQPyExpression):
                self.type = operand.type
                break
        else:
            raise TypeError("Neither lhs nor rhs is an expression?")

    def to_ast(self, program: Program) -> ast.BinaryExpression:
        """Converts the OQpy expression into an ast node."""
        # Left operand is converted before the right one, same as argument order.
        left_node = to_ast(program, self.lhs)
        right_node = to_ast(program, self.rhs)
        return ast.BinaryExpression(self.op, left_node, right_node)
class Var(ABC):
    """Abstract base class for both classical and quantum variables.

    Args:
        name: Identifier of the variable, stored as ``name``.
        needs_declaration: Flag stored as ``_needs_declaration``; presumably tells the
            program whether a declaration statement must still be emitted (confirm with
            callers).
    """

    def __init__(self, name: str, needs_declaration: bool = True):
        self.name = name
        self._needs_declaration = needs_declaration

    def _var_matches(self, other: Any) -> bool:
        """Report whether ``other`` represents the same variable as this object.

        Needed because ``==`` is overloaded for expressions: expression subclasses are
        compared structurally with ``expr_matches``, everything else with ``==``.
        """
        if not isinstance(self, OQPyExpression):
            return self == other
        return expr_matches(self, other)

    @abstractmethod
    def to_ast(self, program: Program) -> ast.Expression:
        """Converts the OQpy variable into an ast node."""

    @abstractmethod
    def make_declaration_statement(self, program: Program) -> ast.Statement:
        """Make an ast statement that declares the OQpy variable."""
@runtime_checkable
class HasToAst(Protocol):
    """Structural type for objects that know how to turn themselves into ast nodes.

    Any object defining a ``to_ast(program)`` method satisfies this protocol.
    """

    def to_ast(self, program: Program) -> ast.Expression:
        """Converts the OQpy object into an ast node."""
# Union of the kinds of values accepted by ``to_ast``: objects with their own
# ``to_ast``, python scalars, iterables, objects implementing ``_to_oqpy_expression``,
# and already-built ast expressions.
AstConvertible = Union[
    HasToAst,
    bool,
    int,
    float,
    complex,
    Iterable,
    ExpressionConvertible,
    ast.Expression,
]
def _complex_to_ast(program: Program, value: complex | np.complexfloating) -> ast.Expression:
    """Convert a complex scalar into a real, imaginary, or ``real +/- imag`` ast node."""
    if value.imag == 0:
        # Purely real: emit exactly what the real part alone would produce.
        return to_ast(program, value.real)
    if value.real == 0:
        # Purely imaginary: a negative value becomes unary minus on a positive literal.
        if value.imag < 0:
            return ast.UnaryExpression(ast.UnaryOperator["-"], ast.ImaginaryLiteral(-value.imag))
        return ast.ImaginaryLiteral(value.imag)
    # Route the real part through ``to_ast`` so a negative real part becomes unary minus
    # on a positive FloatLiteral, matching how negative floats are emitted below.
    real_part = to_ast(program, value.real)
    if value.imag < 0:
        return ast.BinaryExpression(
            ast.BinaryOperator["-"], real_part, ast.ImaginaryLiteral(-value.imag)
        )
    return ast.BinaryExpression(
        ast.BinaryOperator["+"], real_part, ast.ImaginaryLiteral(value.imag)
    )


def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
    """Convert an object to an AST node.

    Conversion is attempted in this order:

    1. objects implementing ``_to_oqpy_expression`` are converted via that expression;
    2. complex scalars (python or numpy) become real, imaginary, or ``a +/- bi`` nodes;
    3. booleans become ``BooleanLiteral`` (checked before ints because ``bool`` is a
       subclass of ``int``);
    4. integers and floats become literals, negative values being expressed as a unary
       minus applied to a positive literal;
    5. iterables other than ``str`` become an ``ArrayLiteral`` of converted elements;
    6. ast expressions are returned unchanged;
    7. any other object with a ``to_ast`` method is converted by calling it.

    Raises:
        TypeError: If ``item`` matches none of the cases above.
    """
    if hasattr(item, "_to_oqpy_expression"):
        return item._to_oqpy_expression().to_ast(program)  # type: ignore[union-attr]
    if isinstance(item, (complex, np.complexfloating)):
        return _complex_to_ast(program, item)
    if isinstance(item, (bool, np.bool_)):
        return ast.BooleanLiteral(item)
    if isinstance(item, (int, np.integer)):
        if item < 0:
            return ast.UnaryExpression(ast.UnaryOperator["-"], ast.IntegerLiteral(-item))
        return ast.IntegerLiteral(item)
    if isinstance(item, (float, np.floating)):
        if item < 0:
            return ast.UnaryExpression(ast.UnaryOperator["-"], ast.FloatLiteral(-item))
        return ast.FloatLiteral(item)
    # A str is iterable and each of its elements is again a str, so treating it as an
    # array recursed until RecursionError. Let it fall through to the TypeError below.
    if isinstance(item, Iterable) and not isinstance(item, str):
        return ast.ArrayLiteral([to_ast(program, i) for i in item])
    if isinstance(item, ast.Expression):
        return item
    if hasattr(item, "to_ast"):  # Using isinstance(HasToAst) slowish
        return item.to_ast(program)  # type: ignore[union-attr]
    raise TypeError(f"Cannot convert {item} of type {type(item)} to ast")
def optional_ast(program: Program, item: AstConvertible | None) -> ast.Expression | None:
    """Convert ``item`` to an ast node, passing ``None`` through unchanged."""
    return None if item is None else to_ast(program, item)
def map_to_ast(program: Program, items: Iterable[AstConvertible]) -> list[ast.Expression]:
    """Convert every element of ``items`` into an ast node.

    Elements are converted in iteration order and returned as a new list.
    """
    converted = []
    for entry in items:
        converted.append(to_ast(program, entry))
    return converted