Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VJP/JVP capabilities to DefaultQubit2 #4374

Merged
merged 22 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
267 changes: 266 additions & 1 deletion pennylane/devices/experimental/default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

from functools import partial
from numbers import Number
from typing import Union, Callable, Tuple, Optional, Sequence
import concurrent.futures
import os
Expand All @@ -31,7 +32,7 @@
from .execution_config import ExecutionConfig, DefaultExecutionConfig
from ..qubit.simulate import simulate
from ..qubit.preprocess import preprocess, validate_and_expand_adjoint
from ..qubit.adjoint_jacobian import adjoint_jacobian
from ..qubit.adjoint_jacobian import adjoint_jacobian, adjoint_vjp, adjoint_jvp

Result_or_ResultBatch = Union[Result, ResultBatch]
QuantumTapeBatch = Sequence[QuantumTape]
Expand Down Expand Up @@ -283,6 +284,270 @@ def compute_derivatives(
f"{self.name} cannot compute derivatives via {execution_config.gradient_method}"
)

def execute_and_compute_derivatives(
    self,
    circuits: QuantumTape_or_Batch,
    execution_config: ExecutionConfig = DefaultExecutionConfig,
):
    """Execute one circuit or a batch of circuits and compute their adjoint Jacobians.

    Every circuit is simulated with ``return_final_state=True`` and the final state is
    passed on to ``adjoint_jacobian`` so that it is reused for the derivative.

    Args:
        circuits (Union[QuantumTape, Sequence[QuantumTape]]): the circuit or batch of circuits
        execution_config (ExecutionConfig): configuration of the computation; only
            ``gradient_method == "adjoint"`` is handled by this method

    Returns:
        Tuple: ``(results, jacobians)``. When a single ``QuantumScript`` is given, these are
        the result and the Jacobian of that circuit; for a batch, each is a tuple with one
        entry per circuit.

    Raises:
        NotImplementedError: if ``execution_config.gradient_method`` is not ``"adjoint"``
    """
    is_single_circuit = False
    if isinstance(circuits, QuantumScript):
        is_single_circuit = True
        circuits = [circuits]

    if self.tracker.active:
        for c in circuits:
            self.tracker.update(resources=c.specs["resources"])
        self.tracker.update(batches=1, executions=len(circuits))
        self.tracker.update(derivative_batches=1, derivatives=len(circuits))
        self.tracker.record()

    if execution_config.gradient_method != "adjoint":
        raise NotImplementedError(
            f"{self.name} cannot compute derivatives via {execution_config.gradient_method}"
        )

    max_workers = self._get_max_workers(execution_config)
    if max_workers is None:
        results = tuple(
            simulate(c, rng=self._rng, debugger=self._debugger, return_final_state=True)
            for c in circuits
        )
        jacs = tuple(adjoint_jacobian(c, state=r[1]) for c, r in zip(circuits, results))
        results = tuple(r[0] for r in results)
    else:
        self._validate_multiprocessing_circuits(circuits)

        vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
        seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            # ProcessPoolExecutor pickles the callable it is given. A function defined
            # inside this method cannot be pickled, so the work is done by a static
            # helper that can be looked up by its qualified name in the worker process.
            pairs = tuple(
                executor.map(self._simulate_and_adjoint_jacobian, vanilla_circuits, seeds)
            )

        # zip(*()) yields nothing to unpack, so an empty batch is handled explicitly
        # to match the serial branch, which returns two empty tuples
        results, jacs = tuple(zip(*pairs)) if pairs else ((), ())

        # reset _rng to mimic serial behavior
        self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

    return (results[0], jacs[0]) if is_single_circuit else (results, jacs)

@staticmethod
def _simulate_and_adjoint_jacobian(circuit, rng):
    """Worker for the multiprocessing branch of ``execute_and_compute_derivatives``.

    Simulates ``circuit`` (without a debugger) and computes its adjoint Jacobian from the
    final state. Kept as a static method, not a closure, so it can be pickled.

    Args:
        circuit (QuantumTape): the circuit to simulate and differentiate
        rng: value forwarded to ``simulate`` as its ``rng`` argument

    Returns:
        Tuple: ``(result, jacobian)`` for ``circuit``
    """
    res, final_state, _ = simulate(circuit, rng=rng, debugger=None, return_final_state=True)
    return res, adjoint_jacobian(circuit, state=final_state)

def supports_jvp(
    self,
    execution_config: Optional[ExecutionConfig] = None,
    circuit: Optional[QuantumTape] = None,
) -> bool:
    """Report whether this device can provide a custom jacobian vector product.

    The answer is whatever ``supports_derivatives`` returns for the same arguments;
    no additional checks are made here.

    Args:
        execution_config (ExecutionConfig): The configuration of the desired derivative calculation
        circuit (QuantumTape): An optional circuit to check derivatives support for.

    Returns:
        bool: Whether or not a derivative can be calculated provided the given information
    """
    jvp_supported = self.supports_derivatives(execution_config, circuit)
    return jvp_supported

def compute_jvp(
    self,
    circuits: QuantumTape_or_Batch,
    tangents: Tuple[Number],
    execution_config: ExecutionConfig = DefaultExecutionConfig,
):
    """Compute the jacobian vector product of one circuit or a batch of circuits.

    Each circuit is paired with its tangents and passed to ``adjoint_jvp``.

    Args:
        circuits (Union[QuantumTape, Sequence[QuantumTape]]): the circuit or batch of circuits
        tangents (Tuple[Number]): gradient vector for the input parameters. For a batch of
            circuits, one such entry per circuit.
        execution_config (ExecutionConfig): configuration of the computation; only
            ``gradient_method == "adjoint"`` is handled by this method

    Returns:
        The value returned by ``adjoint_jvp`` for a single ``QuantumScript`` input, or a
        tuple of such values (one per circuit) for a batch.

    Raises:
        NotImplementedError: if ``execution_config.gradient_method`` is not ``"adjoint"``
    """
    is_single_circuit = False
    if isinstance(circuits, QuantumScript):
        is_single_circuit = True
        circuits = [circuits]
        tangents = [tangents]

    if self.tracker.active:
        self.tracker.update(derivative_batches=1, derivatives=len(circuits))
        self.tracker.record()

    if execution_config.gradient_method != "adjoint":
        raise NotImplementedError(
            f"{self.name} cannot compute derivatives via {execution_config.gradient_method}"
        )

    max_workers = self._get_max_workers(execution_config)
    if max_workers is None:
        res = tuple(adjoint_jvp(circuit, tans) for circuit, tans in zip(circuits, tangents))
    else:
        vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            res = tuple(executor.map(adjoint_jvp, vanilla_circuits, tangents))

        # self._rng is deliberately left untouched: the serial branch above draws no
        # random numbers, so reseeding here would make the two modes diverge.

    return res[0] if is_single_circuit else res

def execute_and_compute_jvp(
    self,
    circuits: QuantumTape_or_Batch,
    tangents: Tuple[Number],
    execution_config: ExecutionConfig = DefaultExecutionConfig,
):
    """Execute one circuit or a batch of circuits and compute their jacobian vector products.

    Every circuit is simulated with ``return_final_state=True`` and the final state is
    passed on to ``adjoint_jvp`` so that it is reused for the derivative.

    Args:
        circuits (Union[QuantumTape, Sequence[QuantumTape]]): the circuit or batch of circuits
        tangents (Tuple[Number]): gradient vector for the input parameters. For a batch of
            circuits, one such entry per circuit.
        execution_config (ExecutionConfig): configuration of the computation; only
            ``gradient_method == "adjoint"`` is handled by this method

    Returns:
        Tuple: ``(results, jvps)``. When a single ``QuantumScript`` is given, these are the
        result and the jvp of that circuit; for a batch, each is a tuple with one entry per
        circuit.

    Raises:
        NotImplementedError: if ``execution_config.gradient_method`` is not ``"adjoint"``
    """
    is_single_circuit = False
    if isinstance(circuits, QuantumScript):
        is_single_circuit = True
        circuits = [circuits]
        tangents = [tangents]

    if self.tracker.active:
        for c in circuits:
            self.tracker.update(resources=c.specs["resources"])
        self.tracker.update(batches=1, executions=len(circuits))
        self.tracker.update(derivative_batches=1, derivatives=len(circuits))
        self.tracker.record()

    if execution_config.gradient_method != "adjoint":
        raise NotImplementedError(
            f"{self.name} cannot compute derivatives via {execution_config.gradient_method}"
        )

    max_workers = self._get_max_workers(execution_config)
    if max_workers is None:
        results = tuple(
            simulate(c, rng=self._rng, debugger=self._debugger, return_final_state=True)
            for c in circuits
        )
        jvps = tuple(
            adjoint_jvp(c, t, state=r[1]) for c, t, r in zip(circuits, tangents, results)
        )
        results = tuple(r[0] for r in results)
    else:
        self._validate_multiprocessing_circuits(circuits)

        vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
        seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            # ProcessPoolExecutor pickles the callable it is given. A function defined
            # inside this method cannot be pickled, so the work is done by a static
            # helper that can be looked up by its qualified name in the worker process.
            pairs = tuple(
                executor.map(self._simulate_and_adjoint_jvp, vanilla_circuits, tangents, seeds)
            )

        # zip(*()) yields nothing to unpack, so an empty batch is handled explicitly
        # to match the serial branch, which returns two empty tuples
        results, jvps = tuple(zip(*pairs)) if pairs else ((), ())

        # reset _rng to mimic serial behavior
        self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

    return (results[0], jvps[0]) if is_single_circuit else (results, jvps)

@staticmethod
def _simulate_and_adjoint_jvp(circuit, tangent, rng):
    """Worker for the multiprocessing branch of ``execute_and_compute_jvp``.

    Simulates ``circuit`` (without a debugger) and computes its jvp from the final state.
    Kept as a static method, not a closure, so it can be pickled.

    Args:
        circuit (QuantumTape): the circuit to simulate and differentiate
        tangent (Tuple[Number]): tangents forwarded to ``adjoint_jvp``
        rng: value forwarded to ``simulate`` as its ``rng`` argument

    Returns:
        Tuple: ``(result, jvp)`` for ``circuit``
    """
    res, final_state, _ = simulate(circuit, rng=rng, debugger=None, return_final_state=True)
    return res, adjoint_jvp(circuit, tangent, state=final_state)

def supports_vjp(
    self,
    execution_config: Optional[ExecutionConfig] = None,
    circuit: Optional[QuantumTape] = None,
) -> bool:
    """Report whether this device can provide a custom vector jacobian product.

    The answer is whatever ``supports_derivatives`` returns for the same arguments;
    no additional checks are made here.

    Args:
        execution_config (ExecutionConfig): A description of the hyperparameters for the desired computation.
        circuit (None, QuantumTape): A specific circuit to check differentiation for.

    Returns:
        bool: Whether or not a derivative can be calculated provided the given information
    """
    vjp_supported = self.supports_derivatives(execution_config, circuit)
    return vjp_supported

def compute_vjp(
    self,
    circuits: QuantumTape_or_Batch,
    cotangents: Tuple[Number],
    execution_config: ExecutionConfig = DefaultExecutionConfig,
):
    """Compute the vector jacobian product of one circuit or a batch of circuits.

    Each circuit is paired with its cotangents and passed to ``adjoint_vjp``.

    Args:
        circuits (Union[QuantumTape, Sequence[QuantumTape]]): the circuit or batch of circuits
        cotangents (Tuple[Number]): gradient vector for the output parameters. For a batch
            of circuits, one such entry per circuit.
        execution_config (ExecutionConfig): configuration of the computation; only
            ``gradient_method == "adjoint"`` is handled by this method

    Returns:
        The value returned by ``adjoint_vjp`` for a single ``QuantumScript`` input, or a
        tuple of such values (one per circuit) for a batch.

    Raises:
        NotImplementedError: if ``execution_config.gradient_method`` is not ``"adjoint"``
    """
    is_single_circuit = False
    if isinstance(circuits, QuantumScript):
        is_single_circuit = True
        circuits = [circuits]
        cotangents = [cotangents]

    if self.tracker.active:
        self.tracker.update(derivative_batches=1, derivatives=len(circuits))
        self.tracker.record()

    if execution_config.gradient_method != "adjoint":
        raise NotImplementedError(
            f"{self.name} cannot compute derivatives via {execution_config.gradient_method}"
        )

    max_workers = self._get_max_workers(execution_config)
    if max_workers is None:
        res = tuple(adjoint_vjp(circuit, cots) for circuit, cots in zip(circuits, cotangents))
    else:
        vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            res = tuple(executor.map(adjoint_vjp, vanilla_circuits, cotangents))

        # self._rng is deliberately left untouched: the serial branch above draws no
        # random numbers, so reseeding here would make the two modes diverge.

    return res[0] if is_single_circuit else res

def execute_and_compute_vjp(
    self,
    circuits: QuantumTape_or_Batch,
    cotangents: Tuple[Number],
    execution_config: ExecutionConfig = DefaultExecutionConfig,
):
    """Execute one circuit or a batch of circuits and compute their vector jacobian products.

    Every circuit is simulated with ``return_final_state=True`` and the final state is
    passed on to ``adjoint_vjp`` so that it is reused for the derivative.

    Args:
        circuits (Union[QuantumTape, Sequence[QuantumTape]]): the circuit or batch of circuits
        cotangents (Tuple[Number]): gradient vector for the output parameters. For a batch
            of circuits, one such entry per circuit.
        execution_config (ExecutionConfig): configuration of the computation; only
            ``gradient_method == "adjoint"`` is handled by this method

    Returns:
        Tuple: ``(results, vjps)``. When a single ``QuantumScript`` is given, these are the
        result and the vjp of that circuit; for a batch, each is a tuple with one entry per
        circuit.

    Raises:
        NotImplementedError: if ``execution_config.gradient_method`` is not ``"adjoint"``
    """
    is_single_circuit = False
    if isinstance(circuits, QuantumScript):
        is_single_circuit = True
        circuits = [circuits]
        cotangents = [cotangents]

    if self.tracker.active:
        for c in circuits:
            self.tracker.update(resources=c.specs["resources"])
        self.tracker.update(batches=1, executions=len(circuits))
        self.tracker.update(derivative_batches=1, derivatives=len(circuits))
        self.tracker.record()

    if execution_config.gradient_method != "adjoint":
        raise NotImplementedError(
            f"{self.name} cannot compute derivatives via {execution_config.gradient_method}"
        )

    max_workers = self._get_max_workers(execution_config)
    if max_workers is None:
        results = tuple(
            simulate(c, rng=self._rng, debugger=self._debugger, return_final_state=True)
            for c in circuits
        )
        vjps = tuple(
            adjoint_vjp(c, t, state=r[1]) for c, t, r in zip(circuits, cotangents, results)
        )
        results = tuple(r[0] for r in results)
    else:
        self._validate_multiprocessing_circuits(circuits)

        vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
        seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            # ProcessPoolExecutor pickles the callable it is given. A function defined
            # inside this method cannot be pickled, so the work is done by a static
            # helper that can be looked up by its qualified name in the worker process.
            pairs = tuple(
                executor.map(
                    self._simulate_and_adjoint_vjp, vanilla_circuits, cotangents, seeds
                )
            )

        # zip(*()) yields nothing to unpack, so an empty batch is handled explicitly
        # to match the serial branch, which returns two empty tuples
        results, vjps = tuple(zip(*pairs)) if pairs else ((), ())

        # reset _rng to mimic serial behavior
        self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

    return (results[0], vjps[0]) if is_single_circuit else (results, vjps)

@staticmethod
def _simulate_and_adjoint_vjp(circuit, cotangent, rng):
    """Worker for the multiprocessing branch of ``execute_and_compute_vjp``.

    Simulates ``circuit`` (without a debugger) and computes its vjp from the final state.
    Kept as a static method, not a closure, so it can be pickled.

    Args:
        circuit (QuantumTape): the circuit to simulate and differentiate
        cotangent (Tuple[Number]): cotangents forwarded to ``adjoint_vjp``
        rng: value forwarded to ``simulate`` as its ``rng`` argument

    Returns:
        Tuple: ``(result, vjp)`` for ``circuit``
    """
    res, final_state, _ = simulate(circuit, rng=rng, debugger=None, return_final_state=True)
    return res, adjoint_vjp(circuit, cotangent, state=final_state)

# pylint: disable=missing-function-docstring
def _get_max_workers(self, execution_config=None):
max_workers = None
Expand Down
33 changes: 13 additions & 20 deletions pennylane/devices/qubit/adjoint_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from pennylane.tape import QuantumTape

from .apply_operation import apply_operation
from .initialize_state import create_initial_state

Check notice on line 25 in pennylane/devices/qubit/adjoint_jacobian.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/devices/qubit/adjoint_jacobian.py#L25

Unused create_initial_state imported from initialize_state (unused-import)
from .simulate import _final_state

# pylint: disable=protected-access, too-many-branches

Expand All @@ -34,21 +35,7 @@
return qml.math.real(qml.math.sum(qml.math.conj(bra) * ket, axis=sum_axes))


def _get_output_ket(tape):
    """Build the output state of ``tape`` by applying its operations one after another.

    When the first element of the tape is a ``qml.operation.StatePrep`` it is passed to
    ``create_initial_state`` as the preparation operation and is not applied again;
    otherwise ``create_initial_state`` receives ``prep_operation=None``.
    """
    first = tape[0]
    prep_operation = first if isinstance(first, qml.operation.StatePrep) else None

    # Initial state comes from the preparation operation when there is one
    ket = create_initial_state(wires=tape.wires, prep_operation=prep_operation)

    # bool(...) doubles as the slice start: skip index 0 only if it was consumed above
    skip = bool(prep_operation)
    for gate in tape.operations[skip:]:
        ket = apply_operation(gate, ket)

    return ket


def adjoint_jacobian(tape: QuantumTape):
def adjoint_jacobian(tape: QuantumTape, state=None):
"""Implements the adjoint method outlined in
`Jones and Gacon <https://arxiv.org/abs/2009.02823>`__ to differentiate an input tape.

Expand All @@ -67,6 +54,8 @@

Args:
tape (.QuantumTape): circuit that the function takes the gradient of
state (TensorLike): the final state of the circuit; if not provided,
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
the final state will be computed by executing the tape

Returns:
array or tuple[array]: the derivative of the tape with respect to trainable parameters.
Expand All @@ -77,7 +66,7 @@
wire_map = {w: i for i, w in enumerate(tape.wires)}
tape = qml.map_wires(tape, wire_map)

ket = _get_output_ket(tape)
ket = state if state is not None else _final_state(tape)[0]

n_obs = len(tape.observables)
bras = np.empty([n_obs] + [2] * len(tape.wires), dtype=np.complex128)
Expand Down Expand Up @@ -119,7 +108,7 @@
return tuple(tuple(np.array(j_) for j_ in j) for j in jac)


def adjoint_jvp(tape: QuantumTape, tangents: Tuple[Number]):
def adjoint_jvp(tape: QuantumTape, tangents: Tuple[Number], state=None):
"""The jacobian vector product used in forward mode calculation of derivatives.

Implements the adjoint method outlined in
Expand All @@ -141,6 +130,8 @@
Args:
tape (.QuantumTape): circuit that the function takes the gradient of
tangents (Tuple[Number]): gradient vector for input parameters.
state (TensorLike): the final state of the circuit; if not provided,
the final state will be computed by executing the tape

Returns:
Tuple[Number]: gradient vector for output parameters
Expand All @@ -150,7 +141,7 @@
wire_map = {w: i for i, w in enumerate(tape.wires)}
tape = qml.map_wires(tape, wire_map)

ket = _get_output_ket(tape)
ket = state if state is not None else _final_state(tape)[0]

n_obs = len(tape.observables)
bras = np.empty([n_obs] + [2] * len(tape.wires), dtype=np.complex128)
Expand Down Expand Up @@ -191,7 +182,7 @@
return tuple(np.array(t) for t in tangents_out)


def adjoint_vjp(tape: QuantumTape, cotangents: Tuple[Number]):
def adjoint_vjp(tape: QuantumTape, cotangents: Tuple[Number], state=None):
"""The vector jacobian product used in reverse-mode differentiation.

Implements the adjoint method outlined in
Expand All @@ -213,6 +204,8 @@
Args:
tape (.QuantumTape): circuit that the function takes the gradient of
cotangents (Tuple[Number]): gradient vector for output parameters
state (TensorLike): the final state of the circuit; if not provided,
the final state will be computed by executing the tape

Returns:
Tuple[Number]: gradient vector for input parameters
Expand All @@ -222,7 +215,7 @@
wire_map = {w: i for i, w in enumerate(tape.wires)}
tape = qml.map_wires(tape, wire_map)

ket = _get_output_ket(tape)
ket = state if state is not None else _final_state(tape)[0]

obs = qml.dot(cotangents, tape.observables)
bra = apply_operation(obs, ket)
Expand Down
Loading
Loading