Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions mpqp/execution/connection/aws_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,11 @@ def get_aws_braket_account_info() -> str:
return result


def get_braket_device(device: AWSDevice, is_noisy: bool = False) -> "BraketDevice":
def get_braket_device(
device: AWSDevice,
is_noisy: bool = False,
is_gate_model: bool = True,
) -> "BraketDevice":
"""Returns the AwsDevice device associate with the AWSDevice in parameter.

Args:
Expand All @@ -378,6 +382,11 @@ def get_braket_device(device: AWSDevice, is_noisy: bool = False) -> "BraketDevic
"""
from braket.devices import LocalSimulator

from mpqp.tools.errors import (
AWSBraketRemoteExecutionError,
DeviceJobIncompatibleError,
)

if not device.is_remote():
if is_noisy:
return LocalSimulator("braket_dm")
Expand All @@ -397,7 +406,8 @@ def get_braket_device(device: AWSDevice, is_noisy: bool = False) -> "BraketDevic
aws_session.add_braket_user_agent(
user_agent="APN/1.0 ColibriTD/1.0 MPQP/" + mpqp_version
)
return AwsDevice(device.get_arn(), aws_session=aws_session)
braket_device = AwsDevice(device.get_arn(), aws_session=aws_session)

except ValueError as ve:
raise AWSBraketRemoteExecutionError(
"Failed to retrieve remote AWS device. Please check the arn, or if the "
Expand All @@ -410,6 +420,23 @@ def get_braket_device(device: AWSDevice, is_noisy: bool = False) -> "BraketDevic
"\nTrace: " + str(err)
)

if is_gate_model:
actions = getattr(getattr(braket_device, "properties", None), "action", None)
if actions is not None:
supported = [getattr(k, "value", str(k)) for k in actions.keys()]
supports_gate_model = any(
("openqasm" in action.lower()) or ("jaqcd" in action.lower())
for action in supported
)
if not supports_gate_model:
raise DeviceJobIncompatibleError(
f"{device.name} does not support gate-model workloads. "
f"Supported Braket action types: {supported}. "
"This is an AHS device, which cannot run MPQP QCircuit."
)

return braket_device


def get_all_task_ids() -> list[str]:
"""Retrieves all the task ids of this account/group from AWS.
Expand Down
38 changes: 29 additions & 9 deletions mpqp/execution/providers/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
from mpqp.execution.job import Job, JobStatus, JobType
from mpqp.execution.result import Result, Sample, StateVector
from mpqp.noise.noise_model import NoiseModel
from mpqp.tools.errors import AWSBraketRemoteExecutionError, DeviceJobIncompatibleError
from mpqp.tools.errors import (
AWSBraketRemoteExecutionError,
DeviceJobIncompatibleError,
DeviceJobIncompatibleWarning,
)

if TYPE_CHECKING:
from braket.circuits import Circuit
Expand Down Expand Up @@ -109,16 +113,31 @@ def run_braket(job: Job) -> Result:
f"{job.device} instead"
)

import warnings

from braket.tasks import GateModelQuantumTaskResult

if isinstance(job.measure, ExpectationMeasure):
return run_braket_observable(job)
_, task = submit_job_braket(job)
res = task.result()
if TYPE_CHECKING:
assert isinstance(res, GateModelQuantumTaskResult)
try:
if isinstance(job.measure, ExpectationMeasure):
return run_braket_observable(job)

_, task = submit_job_braket(job)
res = task.result()
if TYPE_CHECKING:
assert isinstance(res, GateModelQuantumTaskResult)

return extract_result(res, job, job.device)
return extract_result(res, job, job.device)

except DeviceJobIncompatibleError as e:
warnings.warn(str(e), DeviceJobIncompatibleWarning, stacklevel=5)

job.status = JobStatus.ERROR
return Result(
job,
data=None,
errors="Unsupported Braket backend for QCircuit (see warning).",
shots=0,
)


def run_braket_observable(job: Job):
Expand Down Expand Up @@ -151,6 +170,7 @@ def run_braket_observable(job: Job):
job.device,
is_noisy=bool(job.circuit.noises),
)

if job.measure is None:
raise NotImplementedError("job.measure is None")
assert isinstance(job.measure, ExpectationMeasure)
Expand Down Expand Up @@ -270,7 +290,7 @@ def run_braket_observable(job: Job):
)

if braket_sum is not None:
from braket.program_sets import ProgramSet, CircuitBinding
from braket.program_sets import CircuitBinding, ProgramSet
from braket.tasks.program_set_quantum_task_result import (
ProgramSetQuantumTaskResult,
)
Expand Down
14 changes: 11 additions & 3 deletions mpqp/execution/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import numpy.typing as npt

from mpqp.core.instruction.measurement.basis_measure import BasisMeasure
from mpqp.execution import Job, JobType
from mpqp.execution import Job, JobStatus, JobType
from mpqp.execution.devices import AvailableDevice
from mpqp.tools.display import clean_1D_array, clean_number_repr
from mpqp.tools.errors import ResultAttributeError
Expand Down Expand Up @@ -288,8 +288,8 @@ class Result:
def __init__(
self,
job: Job,
data: float | dict["str", float] | StateVector | list[Sample],
errors: Optional[float | dict[Any, Any]] = None,
data: float | dict["str", float] | StateVector | list[Sample] | None,
errors: Optional[float | dict[Any, Any] | str] = None,
shots: int = 0,
):
self.job = job
Expand All @@ -305,6 +305,11 @@ def __init__(
"""See parameter description."""
self._data = data

if data is None:
if job.status != JobStatus.ERROR:
raise TypeError("Result data cannot be None unless job.status == ERROR")
return

# depending on the type of job, fills the result info from the data in parameter
if job.job_type == JobType.OBSERVABLE:
if not isinstance(data, float) and not isinstance(data, dict):
Expand Down Expand Up @@ -458,6 +463,9 @@ def __str__(self):
label = "" if self.job.circuit.label is None else self.job.circuit.label + ", "
header = f"Result: {label}{type(self.device).__name__}, {self.device.name}"

if self.job.status == JobStatus.ERROR:
return f"{header}\n Error: {self.error}"

if self.job.job_type == JobType.SAMPLE:
measures = self.job.circuit.measurements
if not len(measures) == 1:
Expand Down
4 changes: 4 additions & 0 deletions mpqp/tools/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class DeviceJobIncompatibleError(ValueError):
for the selected device (for example SAMPLE job on a statevector simulator)."""


class DeviceJobIncompatibleWarning(UserWarning):
"""A warning is issued when a job is not compatible with the selected device."""


class RemoteExecutionError(ConnectionError):
"""Raised when an error occurred during a remote connection, submission or
execution."""
Expand Down