############################################################################
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
############################################################################
"""Classes representing oqpy variables with classical types."""
from __future__ import annotations

import functools
import numbers
import random
import string
import sys
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Iterable,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
)

from openpulse import ast

from oqpy.base import (
    AstConvertible,
    OQPyExpression,
    Var,
    make_annotations,
    map_to_ast,
    optional_ast,
    to_ast,
)
from oqpy.timing import convert_float_to_duration
if TYPE_CHECKING:
    # These names are only needed by static type checkers (annotations are postponed).
    from typing import Literal

    from oqpy.program import Program

# ``types.EllipsisType`` only exists from Python 3.10 on; older interpreters derive
# the same class from the ``...`` singleton.
if sys.version_info >= (3, 10):
    from types import EllipsisType
else:
    EllipsisType = type(Ellipsis)
# Public API of this module: the names exported by ``import *``.
__all__ = [
    "pi",
    "ArrayVar",
    "BoolVar",
    "IntVar",
    "UintVar",
    "FloatVar",
    "AngleVar",
    "BitVar",
    "ComplexVar",
    "DurationVar",
    "OQFunctionCall",
    "OQIndexExpression",
    "StretchVar",
    "_ClassicalVar",
    "duration",
    "stretch",
    "bool_",
    "bit_",
    "bit",
    "bit8",
    "convert_range",
    "int_",
    "int32",
    "int64",
    "uint_",
    "uint32",
    "uint64",
    "float_",
    "float32",
    "float64",
    "complex_",
    "complex64",
    "complex128",
    "angle_",
    "angle32",
    "arrayreference_",
]
# The following functions and constants are useful for creating signatures
# for OpenQASM function calls, as is required when specifying
# waveform-generating functions.
# If you wish to create a variable with a particular type, please use the
# subclasses of ``_ClassicalVar`` instead.
def int_(size: int | None = None) -> ast.IntType:
    """Return a signed ``int`` type node, sized to ``size`` bits when given."""
    width = None if size is None else ast.IntegerLiteral(size)
    return ast.IntType(width)
def uint_(size: int | None = None) -> ast.UintType:
    """Return an unsigned ``uint`` type node, sized to ``size`` bits when given."""
    width = None if size is None else ast.IntegerLiteral(size)
    return ast.UintType(width)
def float_(size: int | None = None) -> ast.FloatType:
    """Return a ``float`` type node, sized to ``size`` bits when given."""
    width = None if size is None else ast.IntegerLiteral(size)
    return ast.FloatType(width)
def angle_(size: int | None = None) -> ast.AngleType:
    """Return an ``angle`` type node, sized to ``size`` bits when given."""
    width = None if size is None else ast.IntegerLiteral(size)
    return ast.AngleType(width)
def complex_(size: int | None = None) -> ast.ComplexType:
    """Create a sized complex type.

    Note the size represents the total size, and thus the components have
    half of the requested size.

    Args:
        size: Total bit width of the complex number. It must be even so that it
            can be split evenly between the two float components. ``None``
            (consistent with the other type helpers) leaves the float component
            type unsized.

    Returns:
        A ``ComplexType`` whose base type is a ``FloatType`` of ``size // 2`` bits.

    Raises:
        ValueError: If ``size`` is odd, since it cannot be halved exactly.
    """
    if size is None:
        return ast.ComplexType(ast.FloatType(None))
    if size % 2:
        # Previously an odd size was silently truncated by ``size // 2``.
        raise ValueError(f"The size of a complex type must be even, got {size}.")
    return ast.ComplexType(ast.FloatType(ast.IntegerLiteral(size // 2)))
def bit_(size: int | None = None) -> ast.BitType:
    """Return a ``bit`` type node, sized to ``size`` bits when given."""
    width = None if size is None else ast.IntegerLiteral(size)
    return ast.BitType(width)
def arrayreference_(
    dtype: Union[
        ast.IntType,
        ast.UintType,
        ast.FloatType,
        ast.AngleType,
        ast.DurationType,
        ast.BitType,
        ast.BoolType,
        ast.ComplexType,
    ],
    dims: int | list[int],
) -> ast.ArrayReferenceType:
    """Build an array-reference type node with element type ``dtype``.

    ``dims`` may be a single int, which becomes one ``IntegerLiteral``, or a list
    of ints, which becomes a list with one ``IntegerLiteral`` per dimension.
    """
    # NOTE(review): a bare int presumably encodes the ``#dim = n`` form of an
    # array reference — confirm against the openpulse AST printer.
    if isinstance(dims, int):
        dimensions = ast.IntegerLiteral(dims)
    else:
        dimensions = [ast.IntegerLiteral(extent) for extent in dims]
    return ast.ArrayReferenceType(base_type=dtype, dimensions=dimensions)
# Ready-made type nodes for building signatures. Each name is a single
# module-level instance that is reused wherever it is referenced.
duration = ast.DurationType()
stretch = ast.StretchType()
bool_ = ast.BoolType()
bit = ast.BitType()  # unsized bit
bit8 = bit_(8)
int32 = int_(32)
int64 = int_(64)
uint32 = uint_(32)
uint64 = uint_(64)
float32 = float_(32)
float64 = float_(64)
complex64 = complex_(64)  # total width; each float component is 32 bits
complex128 = complex_(128)  # total width; each float component is 64 bits
angle32 = angle_(32)
def convert_range(program: Program, item: Union[slice, range]) -> ast.RangeDefinition:
    """Convert a slice or range into an ast node.

    Python ranges and slices exclude their stop value, while the produced
    ``RangeDefinition`` has an inclusive end. The end is therefore moved one step
    back toward the start:
    - ``stop - 1`` for a positive step.
    - ``stop + 1`` when the step is a negative number.
    - ``stop - 1`` for a symbolic step, because its sign cannot be known here.

    Components that a slice omits (``None``) are omitted from the range as well.
    A step equal to ``1`` is also omitted.

    Args:
        program: Program used to convert the range components into ast nodes.
        item: The slice or range to convert.

    Returns:
        The equivalent ``ast.RangeDefinition``.
    """
    start, stop, step = item.start, item.stop, item.step
    # Only concrete numbers may be compared with 0/1 here; comparing an
    # OQPyExpression would build an expression rather than a bool.
    numeric_step = isinstance(step, numbers.Real)

    if stop is None:
        end = None
    elif numeric_step and step < 0:
        end = to_ast(program, stop + 1)
    else:
        end = to_ast(program, stop - 1)

    if step is None or (numeric_step and step == 1):
        step_ast = None
    else:
        step_ast = to_ast(program, step)

    return ast.RangeDefinition(optional_ast(program, start), end, step_ast)
class Identifier(OQPyExpression):
    """Expression referring to a named constant symbol with a fixed classical type."""

    name: str

    def __init__(self, name: str, ast_type: ast.ClassicalType) -> None:
        self.type = ast_type
        self.name = name

    def to_ast(self, program: Program) -> ast.Expression:
        """Return an identifier node for this symbol's name; ``program`` is not used."""
        node = ast.Identifier(name=self.name)
        return node
# Symbolic reference to the constant ``pi``, typed as an unsized float.
pi = Identifier(name="pi", ast_type=ast.FloatType())
class _ClassicalVar(Var, OQPyExpression):
    """Common base for oqpy variables that carry a classical OpenQASM type.

    Concrete subclasses must define ``type_cls``. That AST type class is
    instantiated with any extra keyword arguments to form ``self.type``.
    """

    type_cls: Type[ast.ClassicalType]

    def __init__(
        self,
        init_expression: AstConvertible | Literal["input", "output"] | None = None,
        name: str | None = None,
        needs_declaration: bool = True,
        annotations: Sequence[str | tuple[str, str]] = (),
        **type_kwargs: Any,
    ):
        if not name:
            # No (non-empty) name supplied: use a random 10-letter identifier.
            name = "".join(random.choice(string.ascii_letters) for _ in range(10))
        super().__init__(name, needs_declaration=needs_declaration)
        self.type = self.type_cls(**type_kwargs)
        self.init_expression = init_expression
        self.annotations = annotations

    def to_ast(self, program: Program) -> ast.Identifier:
        """Add this variable to ``program`` (via ``_add_var``) and return its identifier node."""
        program._add_var(self)
        return ast.Identifier(self.name)

    def make_declaration_statement(self, program: Program) -> ast.Statement:
        """Build the statement that declares this variable.

        An ``"input"``/``"output"`` initial value yields an ``IODeclaration``.
        Anything else yields a ``ClassicalDeclaration`` that carries the converted
        initial value (if any) and this variable's annotations.
        """
        initial = self.init_expression
        is_io = isinstance(initial, str) and initial in ("input", "output")
        if is_io:
            io_keyword = ast.IOKeyword[initial]
            return ast.IODeclaration(io_keyword, self.type, self.to_ast(program))
        # Convert the initializer before ``to_ast`` adds this variable to the program.
        initial_ast = optional_ast(program, initial)
        declaration = ast.ClassicalDeclaration(self.type, self.to_ast(program), initial_ast)
        declaration.annotations = make_annotations(self.annotations)
        return declaration
class BoolVar(_ClassicalVar):
    """Unsized oqpy variable of OpenQASM ``bool`` type."""

    type_cls = ast.BoolType
class _SizedVar(_ClassicalVar):
    """Base class for variables with a specified size."""

    default_size: int | None = None
    size: int | None

    def __class_getitem__(cls: Type[_SizedVarT], item: int | None) -> Callable[..., _SizedVarT]:
        # Allows IntVar[64]() notation
        return functools.partial(cls, size=item)

    def __init__(self, *args: Any, size: int | None | EllipsisType = ..., **kwargs: Any):
        """Create a sized variable.

        Args:
            *args: Forwarded to ``_ClassicalVar.__init__``.
            size: Bit width. ``...`` (the default) selects the class's
                ``default_size``; ``None`` creates an unsized variable.
            **kwargs: Forwarded to ``_ClassicalVar.__init__``.

        Raises:
            ValueError: If ``size`` is not a positive integer. Booleans are
                rejected even though ``bool`` subclasses ``int``.
        """
        if size is ...:
            self.size = self.default_size
        elif size is None:
            self.size = size
        else:
            if isinstance(size, bool) or not isinstance(size, int) or size <= 0:
                raise ValueError(
                    f"The size of '{self.type_cls}' objects must be a positive integer."
                )
            self.size = size
        super().__init__(*args, **kwargs, size=ast.IntegerLiteral(self.size) if self.size else None)

    def _validate_getitem_index(self, index: AstConvertible) -> None:
        """Validate an index passed to ``__getitem__`` on this variable.

        Args:
            index (AstConvertible): Index for ``__getitem__``. It must be a Python
                ``int`` in ``[0, size)`` or an OQPyExpression of ``int``/``uint`` type.

        Raises:
            TypeError: If this variable is unsized and therefore not subscriptable.
            IndexError: If ``index`` is out of range or is not an integer.
        """
        if self.size is None:
            raise TypeError(f"'{self.name}' is not subscriptable")
        if isinstance(index, int):
            if not 0 <= index < self.size:
                raise IndexError("list index out of range.")
        elif isinstance(index, OQPyExpression):
            if not isinstance(index.type, (ast.IntType, ast.UintType)):
                raise IndexError("The list index must be an integer.")
        else:
            raise IndexError("The list index must be an integer.")
# Type variable bound to _SizedVar so that ``__class_getitem__`` is typed as
# returning a factory for the concrete subclass it is called on.
_SizedVarT = TypeVar("_SizedVarT", bound=_SizedVar)
class IntVar(_SizedVar):
    """Oqpy variable of signed ``int`` type (32 bits unless another size is given)."""

    default_size = 32
    type_cls = ast.IntType
class UintVar(_SizedVar):
    """Oqpy variable of unsigned ``uint`` type (32 bits unless another size is given)."""

    default_size = 32
    type_cls = ast.UintType
class FloatVar(_SizedVar):
    """Oqpy variable of ``float`` type (64 bits unless another size is given)."""

    default_size = 64
    type_cls = ast.FloatType
class AngleVar(_SizedVar):
    """Oqpy variable of ``angle`` type (32 bits unless another size is given)."""

    default_size = 32
    type_cls = ast.AngleType
class BitVar(_SizedVar):
    """Oqpy variable of ``bit`` type (unsized unless a size is given)."""

    type_cls = ast.BitType

    def __getitem__(self, index: AstConvertible) -> OQIndexExpression:
        """Return an expression for a single bit of this register.

        The index is validated by ``_validate_getitem_index``. The resulting
        expression has an unsized ``bit`` type.
        """
        self._validate_getitem_index(index)
        single_bit = self.type_cls()
        return OQIndexExpression(collection=self, index=index, type_=single_bit)
class ComplexVar(_ClassicalVar):
    """An oqpy variable with complex type.

    ``base_type`` is the float type used for the complex type's components.
    It defaults to ``float64``.
    """

    type_cls = ast.ComplexType
    base_type: ast.FloatType = float64

    def __class_getitem__(cls, item: ast.FloatType) -> Callable[..., ComplexVar]:
        # Allows ComplexVar[float32](...) notation to choose the component type.
        return functools.partial(cls, base_type=item)

    def __init__(
        self,
        init_expression: AstConvertible | Literal["input", "output"] | None = None,
        *args: Any,
        base_type: ast.FloatType = float64,
        **kwargs: Any,
    ) -> None:
        """Create a complex variable.

        Args:
            init_expression: Initial value. ``None``, ``complex`` values, strings
                (such as ``"input"``/``"output"``) and OQPyExpressions are kept as
                given. Any other value is coerced with ``complex()``.
            *args: Forwarded to ``_ClassicalVar.__init__`` (e.g. ``name``).
            base_type: Float type of the complex components.
            **kwargs: Forwarded to ``_ClassicalVar.__init__``.
        """
        assert isinstance(base_type, ast.FloatType)
        self.base_type = base_type
        if not isinstance(init_expression, (complex, type(None), str, OQPyExpression)):
            init_expression = complex(init_expression)  # type: ignore[arg-type]
        super().__init__(init_expression, *args, **kwargs, base_type=base_type)
class DurationVar(_ClassicalVar):
    """Oqpy variable of OpenQASM ``duration`` type."""

    type_cls = ast.DurationType

    def __init__(
        self,
        init_expression: AstConvertible | Literal["input", "output"] | None = None,
        name: str | None = None,
        *args: Any,
        **type_kwargs: Any,
    ) -> None:
        # Strings (e.g. "input"/"output") and None are kept; everything else is
        # turned into a duration expression by convert_float_to_duration.
        keep_as_is = init_expression is None or isinstance(init_expression, str)
        if not keep_as_is:
            init_expression = convert_float_to_duration(init_expression)
        super().__init__(init_expression, name, *args, **type_kwargs)
class StretchVar(_ClassicalVar):
    """Oqpy variable of OpenQASM ``stretch`` type."""

    type_cls = ast.StretchType
# Variable classes accepted as the element type (``base_type``) of an ArrayVar.
AllowedArrayTypes = Union[_SizedVar, DurationVar, BoolVar, ComplexVar]
class ArrayVar(_ClassicalVar):
    """An oqpy array variable."""

    type_cls = ast.ArrayType
    dimensions: list[int]
    base_type: type[AllowedArrayTypes]

    def __class_getitem__(
        cls, item: tuple[type[AllowedArrayTypes], int] | type[AllowedArrayTypes]
    ) -> Callable[..., ArrayVar]:
        # Allows usage like ArrayVar[FloatVar, 32](...) or ArrayVar[FloatVar]
        if isinstance(item, tuple):
            base_type = item[0]
            dimensions = list(item[1:])
            return functools.partial(cls, dimensions=dimensions, base_type=base_type)
        return functools.partial(cls, base_type=item)

    def __init__(
        self,
        *args: Any,
        dimensions: list[int],
        base_type: type[AllowedArrayTypes] = IntVar,
        **kwargs: Any,
    ) -> None:
        """Create an array variable.

        Args:
            *args: Forwarded to ``_ClassicalVar.__init__``. The first one, if
                given, is the initial value.
            dimensions: Size of each array dimension.
            base_type: Variable class (or a factory such as ``IntVar[64]``) that
                describes the element type.
            **kwargs: Forwarded to ``_ClassicalVar.__init__`` (e.g.
                ``init_expression``, ``name``).
        """
        self.dimensions = dimensions
        self.base_type = base_type

        # Creating a dummy variable supports IntVar[64] etc. Its ``type`` is exactly
        # the element type that variable would declare. An unsized base (e.g.
        # BitVar) therefore stays unsized instead of getting IntegerLiteral(None).
        array_base_type = base_type().type

        # Automatically handle Duration arrays. The initial value may be passed by
        # keyword or as the first positional argument.
        if base_type is DurationVar:
            if "init_expression" in kwargs:
                kwargs["init_expression"] = self._convert_duration_elements(
                    kwargs["init_expression"]
                )
            elif args:
                args = (self._convert_duration_elements(args[0]), *args[1:])

        super().__init__(
            *args,
            **kwargs,
            dimensions=[ast.IntegerLiteral(dimension) for dimension in dimensions],
            base_type=array_base_type,
        )

    @staticmethod
    def _convert_duration_elements(init_expression: Any) -> Any:
        """Convert each element of a duration-array initializer into a duration expression."""
        # None, "input"/"output" and expressions (e.g. another array variable) are
        # passed through. Iterating a string would yield characters, and iterating
        # an indexable expression through __getitem__ could loop indefinitely.
        if init_expression is None or isinstance(init_expression, (str, OQPyExpression)):
            return init_expression
        if not isinstance(init_expression, Iterable):
            return init_expression
        # Build a list rather than a one-shot generator, so the converted value
        # can be consumed more than once.
        return [convert_float_to_duration(element) for element in init_expression]

    def __getitem__(self, index: AstConvertible) -> OQIndexExpression:
        """Return an index expression into this array, typed with the (unsized) element type class."""
        return OQIndexExpression(collection=self, index=index, type_=self.base_type().type_cls())
class OQIndexExpression(OQPyExpression):
    """Expression that indexes into a collection, e.g. ``collection[index]``."""

    def __init__(self, collection: AstConvertible, index: AstConvertible, type_: ast.ClassicalType):
        self.type = type_
        self.collection = collection
        self.index = index

    def to_ast(self, program: Program) -> ast.IndexExpression:
        """Convert the collection, then the index, and combine them into an index node."""
        collection_node = to_ast(program, self.collection)
        index_node = to_ast(program, self.index)
        return ast.IndexExpression(collection=collection_node, index=[index_node])
class OQFunctionCall(OQPyExpression):
    """Expression representing a call to a named function."""

    def __init__(
        self,
        identifier: Union[str, ast.Identifier],
        args: Union[Iterable[AstConvertible], dict[Any, AstConvertible]],
        return_type: Optional[ast.ClassicalType],
        extern_decl: ast.ExternDeclaration | None = None,
        subroutine_decl: ast.SubroutineDefinition | None = None,
    ):
        """Create a new OQFunctionCall instance.

        Args:
            identifier: The function name, as a string or an identifier node.
            args: The call arguments. When a dict is given, only its values are
                used as the arguments of the generated FunctionCall node.
            return_type: The type the call evaluates to; ``None`` means it
                returns nothing.
            extern_decl: Optional extern declaration node. If present, it is
                stored in the program's externs under the function name every
                time this call is converted to ast.
            subroutine_decl: Optional subroutine definition node. If present, it
                is added to the program under the function name every time this
                call is converted to ast.
        """
        super().__init__()
        self.identifier = ast.Identifier(identifier) if isinstance(identifier, str) else identifier
        self.args = args
        self.type = return_type
        self.extern_decl = extern_decl
        self.subroutine_decl = subroutine_decl

    def to_ast(self, program: Program) -> ast.Expression:
        """Register any attached declaration with ``program`` and build the call node."""
        function_name = self.identifier.name
        if self.extern_decl is not None:
            program.externs[function_name] = self.extern_decl
        if self.subroutine_decl is not None:
            program._add_subroutine(function_name, self.subroutine_decl)
        if isinstance(self.args, dict):
            call_args = self.args.values()
        else:
            call_args = self.args
        return ast.FunctionCall(self.identifier, map_to_ast(program, call_args))