Skip to content

Commit

Permalink
map_wires support for a batch of tapes (#6295)
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro authored Sep 25, 2024
1 parent 912f188 commit 215ab98
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 13 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@

<h3>Bug fixes 🐛</h3>

* `qml.map_wires` can now be applied to a batch of tapes.
[(#6295)](https://github.com/PennyLaneAI/pennylane/pull/6295)

* Fix float-to-complex casting in various places across PennyLane.
[(#6260)](https://github.com/PennyLaneAI/pennylane/pull/6260)
[(#6268)](https://github.com/PennyLaneAI/pennylane/pull/6268)
Expand Down
45 changes: 33 additions & 12 deletions pennylane/ops/functions/map_wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
This module contains the qml.map_wires function.
"""
from collections.abc import Callable
from functools import partial
from typing import Union
from typing import Union, overload

import pennylane as qml
from pennylane import transform
Expand All @@ -28,8 +27,32 @@
from pennylane.workflow import QNode


@overload
def map_wires(
input: Union[Operator, MeasurementProcess, QuantumScript, QNode, Callable],
input: Operator, wire_map: dict, queue: bool = False, replace: bool = False
) -> Operator: ...
@overload
def map_wires(
input: MeasurementProcess, wire_map: dict, queue: bool = False, replace: bool = False
) -> MeasurementProcess: ...
@overload
def map_wires(
input: QuantumScript, wire_map: dict, queue: bool = False, replace: bool = False
) -> tuple[QuantumScriptBatch, PostprocessingFn]: ...
@overload
def map_wires(
input: QNode, wire_map: dict, queue: bool = False, replace: bool = False
) -> QNode: ...
@overload
def map_wires(
input: Callable, wire_map: dict, queue: bool = False, replace: bool = False
) -> Callable: ...
@overload
def map_wires(
input: QuantumScriptBatch, wire_map: dict, queue: bool = False, replace: bool = False
) -> tuple[QuantumScriptBatch, PostprocessingFn]: ...
def map_wires(
input: Union[Operator, MeasurementProcess, QuantumScript, QNode, Callable, QuantumScriptBatch],
wire_map: dict,
queue=False,
replace=False,
Expand Down Expand Up @@ -101,13 +124,15 @@ def map_wires(
qml.apply(new_op)
return new_op
return input.map_wires(wire_map=wire_map)
if isinstance(input, (QuantumScript, QNode)) or callable(input):
return _map_wires_transform(input, wire_map=wire_map, queue=queue)
return _map_wires_transform(input, wire_map=wire_map, queue=queue)

raise ValueError(f"Cannot map wires of object {input} of type {type(input)}.")

def processing_fn(res):
"""An empty postprocessing function that leaves the results unchanged."""
return res[0]

@partial(transform)

@transform
def _map_wires_transform(
tape: QuantumScript, wire_map=None, queue=False
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
Expand All @@ -125,8 +150,4 @@ def _map_wires_transform(
ops=ops, measurements=measurements, shots=tape.shots, trainable_params=tape.trainable_params
)

def processing_fn(res):
"""Defines how matrix works if applied to a tape containing multiple operations."""
return res[0]

return [out], processing_fn
return (out,), processing_fn
15 changes: 14 additions & 1 deletion tests/ops/functions/test_map_wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_map_wires_with_queuing_and_with_replacing(self):

def test_map_wires_unsupported_object_raises_error(self):
"""Test that an error is raised when trying to map the wires of an unsupported object."""
with pytest.raises(ValueError, match="Cannot map wires of object"):
with pytest.raises(qml.transforms.core.TransformError, match="Decorating a QNode with"):
qml.map_wires("unsupported type", wire_map=wire_map)


Expand Down Expand Up @@ -163,6 +163,19 @@ def test_map_wires_nested_tape(self):
assert nested_m_tape.operations == [qml.PauliY(4), qml.Hadamard(3), qml.PauliY(1)]
assert len(nested_m_tape.measurements) == 0

def test_map_wires_batch(self):
"""Test that map_wires can be applied to a batch of tapes."""

t1 = qml.tape.QuantumScript([qml.X(0)], [qml.expval(qml.Z(0))])
t2 = qml.tape.QuantumScript([qml.Y(1)], [qml.probs(wires=1)])

batch, _ = qml.map_wires((t1, t2), {0: "a", 1: "b"})

expected1 = qml.tape.QuantumScript([qml.X("a")], [qml.expval(qml.Z("a"))])
expected2 = qml.tape.QuantumScript([qml.Y("b")], [qml.probs(wires="b")])
qml.assert_equal(batch[0], expected1)
qml.assert_equal(batch[1], expected2)


class TestMapWiresQNodes:
"""Tests for the qml.map_wires method used with qnodes."""
Expand Down

0 comments on commit 215ab98

Please sign in to comment.