Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update interfaces module to use bind_new_parameters #4345

Merged
merged 16 commits into from
Jul 14, 2023
Merged
3 changes: 1 addition & 2 deletions pennylane/interfaces/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@

def _set_copy_and_unwrap_tape(t, a, unwrap=True):
"""Copy a given tape with operations and set parameters"""
tc = t.copy(copy_operations=True)
tc.set_parameters(a)
tc = qml.tape.qscript.bind_new_parameters_tape(t, a, t.trainable_params)
return convert_to_numpy_parameters(tc) if unwrap else tc


Expand Down
13 changes: 2 additions & 11 deletions pennylane/interfaces/jax_jit.py
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from pennylane.interfaces import InterfaceUnsupportedError
from pennylane.interfaces.jax import _raise_vector_valued_fwd
from pennylane.measurements import ProbabilityMP
from pennylane.transforms import convert_to_numpy_parameters

from .jax import set_parameters_on_copy_and_unwrap

dtype = jnp.float64

Expand Down Expand Up @@ -143,16 +144,6 @@ def _execute_legacy(
): # pylint: disable=dangerous-default-value,unused-argument
total_params = np.sum([len(p) for p in params])

# Copy a given tape with operations and set parameters
def _set_copy_and_unwrap_tape(t, a, unwrap=True):
tc = t.copy(copy_operations=True)
tc.set_parameters(a)
return convert_to_numpy_parameters(tc) if unwrap else tc

def set_parameters_on_copy_and_unwrap(tapes, params, unwrap=True):
"""Copy a set of tapes with operations and set parameters"""
return tuple(_set_copy_and_unwrap_tape(t, a, unwrap=unwrap) for t, a in zip(tapes, params))

@jax.custom_vjp
def wrapped_exec(params):
result_shapes_dtypes = _extract_shape_dtype_structs(tapes, device)
Expand Down
4 changes: 2 additions & 2 deletions pennylane/interfaces/jax_jit_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@
from pennylane.interfaces.jax_jit import _numeric_type_to_dtype
from pennylane.transforms import convert_to_numpy_parameters


dtype = jnp.float64
Zero = jax.custom_derivatives.SymbolicZero


def _set_copy_and_unwrap_tape(t, a, unwrap=True):
"""Copy a given tape with operations and set parameters"""
tc = t.copy(copy_operations=True)
tc.set_parameters(a, trainable_only=False)
tc = qml.tape.qscript.bind_new_parameters_tape(t, a, list(range(len(a))))
return convert_to_numpy_parameters(tc) if unwrap else tc


Expand Down
3 changes: 1 addition & 2 deletions pennylane/interfaces/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@

def _set_copy_and_unwrap_tape(t, a):
"""Copy a given tape with operations and set parameters"""
tc = t.copy(copy_operations=True)
tc.set_parameters(a, trainable_only=False)
tc = qml.tape.qscript.bind_new_parameters_tape(t, a, list(range(len(a))))
return convert_to_numpy_parameters(tc)


Expand Down
14 changes: 1 addition & 13 deletions pennylane/interfaces/tensorflow_autograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,17 @@

import pennylane as qml
from pennylane.measurements import SampleMP, StateMP
from pennylane.transforms import convert_to_numpy_parameters

from .tensorflow import (
_compute_vjp,
_compute_vjp_legacy,
_jac_restructured,
_res_restructured,
_to_tensors,
set_parameters_on_copy_and_unwrap,
)


def _set_copy_and_unwrap_tape(t, a):
"""Copy a given tape with operations and set parameters"""
tc = t.copy(copy_operations=True)
tc.set_parameters(a, trainable_only=False)
return convert_to_numpy_parameters(tc)


def set_parameters_on_copy_and_unwrap(tapes, params):
"""Copy a set of tapes with operations and set parameters"""
return tuple(_set_copy_and_unwrap_tape(t, a) for t, a in zip(tapes, params))


def _flatten_nested_list(x):
"""
Recursively flatten the list
Expand Down
17 changes: 16 additions & 1 deletion pennylane/pulse/parametrized_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
This file contains the ``ParametrizedEvolution`` operator.
"""

from typing import List, Union
from typing import List, Union, Sequence
import warnings

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.typing import TensorLike
from pennylane.ops import functions

from .parametrized_hamiltonian import ParametrizedHamiltonian
from .hardware_hamiltonian import HardwareHamiltonian
Expand Down Expand Up @@ -502,3 +504,16 @@ def fun(y, t):
elif not self.hyperparameters["return_intermediate"]:
mat = mat[-1]
return qml.math.expand_matrix(mat, wires=self.wires, wire_order=wire_order)


@functions.bind_new_parameters.register
def _bind_new_parameters_parametrized_evol(op: ParametrizedEvolution, params: Sequence[TensorLike]):
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
return ParametrizedEvolution(
op.H,
params=params,
t=op.t,
return_intermediate=op.hyperparameters["return_intermediate"],
complementary=op.hyperparameters["complementary"],
dense=op.dense,
**op.odeint_kwargs,
)
74 changes: 74 additions & 0 deletions pennylane/tape/qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
VarianceMP,
Shots,
)
from pennylane.typing import TensorLike
from pennylane.operation import Observable, Operator, Operation
from pennylane.queuing import AnnotatedQueue, process_queue

Expand Down Expand Up @@ -1410,3 +1411,76 @@ def wrapper(*args, **kwargs):
return qscript

return wrapper


def bind_new_parameters_tape(
tape: QuantumScript, params: Sequence[TensorLike], indices: Sequence[int]
):
"""Create a new tape with updated parameters.

This function takes a :class:`~.tape.QuantumScript` as input, and returns
a new ``QuantumScript`` containing the new parameters at the provided indices,
with the parameters at all other indices remaining the same.

Args:
tape (.tape.QuantumScript): Tape to update
params (Sequence[TensorLike]): New parameters to create the tape with. This
must have the same length as ``indices``.
indices (Sequence[int]): The parameter indices to update with the given parameters.
The index of a parameter is defined as its index in ``tape.get_parameters()``.

Returns:
.tape.QuantumScript: New tape with updated parameters
"""
# pylint: disable=no-member

if len(params) != len(indices):
raise ValueError("Number of provided parameters does not match number of indices")

new_ops = []
idx = 0
p_idx = 0

for op in tape.circuit:
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(op, Operator):
data = op.data
elif op.obs is not None:
data = op.obs.data
else:
data = ()

# determine if any parameters of the operator need to be rebinded
if any(i + idx in indices for i in range(len(data))):
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
new_params = []
for i, d in enumerate(data):
if i + idx in indices:
new_params.append(params[p_idx])
p_idx += 1
else:
new_params.append(d)

if isinstance(op, Operator):
new_op = qml.ops.functions.bind_new_parameters(op, new_params)
else:
new_obs = qml.ops.functions.bind_new_parameters(op.obs, new_params)
new_op = op.__class__(obs=new_obs)

new_ops.append(new_op)
else:
# no need to change the operator
new_ops.append(op)

idx += len(data)

new_prep = new_ops[: len(tape._prep)]
new_operations = new_ops[len(tape._prep) : len(tape.operations)]
new_measurements = new_ops[len(tape.operations) :]

new_tape = qml.tape.QuantumScript(new_operations, new_measurements, new_prep, shots=tape.shots)
new_tape.trainable_params = tape.trainable_params
new_tape._qfunc_output = tape._qfunc_output

return new_tape


QuantumScript.bind_new_parameters = bind_new_parameters_tape
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 3 additions & 1 deletion tests/gradients/core/test_pulse_generator_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,7 +1394,9 @@ def ansatz(params):
assert dev.num_executions == 1 + 12 # one forward execution, dim(DLA)=6
grad_backprop = jax.grad(qnode_backprop)(params)

assert all(qml.math.allclose(r, e) for r, e in zip(grad_pulse_grad, grad_backprop))
assert all(
qml.math.allclose(r, e, atol=1e-7) for r, e in zip(grad_pulse_grad, grad_backprop)
)

@pytest.mark.parametrize("argnums", [[0, 1], 0, 1])
def test_simple_qnode_expval_multiple_params(self, argnums):
Expand Down
26 changes: 1 addition & 25 deletions tests/tape/test_qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_update_observables(self):
)
def test_update_batch_size(self, x, rot, exp_batch_size):
"""Test that the batch size is correctly inferred from all operation's
batch_size, when creating and when using `set_parameters`."""
batch_size when creating a QuantumScript."""

obs = [qml.RX(x, wires=0), qml.Rot(*rot, wires=1)]
m = [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(1))]
Expand Down Expand Up @@ -502,18 +502,6 @@ def test_shallow_copy(self):
# check that the output dim is identical
assert qs.output_dim == copied_qs.output_dim

# since the copy is shallow, mutating the parameters
# on one tape will affect the parameters on another tape
new_params = [np.array([0, 0]), 0.2]
qs.set_parameters(new_params)

# check that they are the same objects in memory
for i, j in zip(qs.get_parameters(), new_params):
assert i is j

for i, j in zip(copied_qs.get_parameters(), new_params):
assert i is j

# pylint: disable=unnecessary-lambda
@pytest.mark.parametrize(
"copy_fn", [lambda tape: tape.copy(copy_operations=True), lambda tape: copy.copy(tape)]
Expand Down Expand Up @@ -550,18 +538,6 @@ def test_shallow_copy_with_operations(self, copy_fn):
# check that the output dim is identical
assert qs.output_dim == copied_qs.output_dim

# Since they have unique operations, mutating the parameters
# on one script will *not* affect the parameters on another script
new_params = [np.array([0, 0]), 0.2]
qs.set_parameters(new_params)

for i, j in zip(qs.get_parameters(), new_params):
assert i is j

for i, j in zip(copied_qs.get_parameters(), new_params):
assert not np.all(i == j)
assert i is not j

def test_deep_copy(self):
"""Test that deep copying a tape works, and copies all constituent data except parameters"""
prep = [qml.BasisState(np.array([1, 0]), wires=(0, 1))]
Expand Down
Loading
Loading