Skip to content

Commit

Permalink
Fix pickle handling for SingletonGate class (#10871)
Browse files Browse the repository at this point in the history
* Fix pickle handling for SingletonGate class

This commit fixes an oversight in #10314 where the handling of pickle
wasn't done correctly. Because the SingletonGate class defines __new__
and based on the parameters to the gate.__class__() call determines
whether we get a new mutable copy or a shared singleton immutable
instance we need special handling in pickle. By default pickle will call
__new__() without any arguments and then rely on __setstate__ to update
the state in the new object. This works fine if the original instance
was a singleton but in the case of mutable copies this will create a
singleton object instead of a mutable copy. To fix this a __reduce__
method is added to ensure arguments get passed to __new__ forcing a
mutable object to be created in deserialization. Then a __setstate__
method is defined to correctly update the mutable object post creation.

* Use __getnewargs_ex__ insetad of __reduce__ & __setstate__

This commit pivots the pickle interface methods used to implement
__getnewargs_ex__ instead of the combination of __reduce__ and
__setstate__. Realistically, all we need to do here is pass that
we have mutable arguments to new to trigger it to create a separate
object, the rest of pickle was working correctly. This makes the
interface being used for pickle a lot clearer.

* Improve assertion in immutable pickle test
  • Loading branch information
mtreinish committed Sep 28, 2023
1 parent e5420d2 commit 1f37e23
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
5 changes: 5 additions & 0 deletions qiskit/circuit/singleton_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def __init__(self, *args, _condition=None, **kwargs):
super().__init__(*args, **kwargs)
self._condition = _condition

def __getnewargs_ex__(self):
if not self.mutable:
return ((), {})
return ((self.label, self._condition, self.duration, self.unit), {})

def c_if(self, classical, val):
if not isinstance(classical, (ClassicalRegister, Clbit)):
raise CircuitError("c_if must be used with a classical register or classical bit")
Expand Down
26 changes: 26 additions & 0 deletions test/python/circuit/test_singleton_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
"""

import copy
import io
import pickle

from qiskit.circuit.library import HGate, SXGate
from qiskit.circuit import Clbit, QuantumCircuit, QuantumRegister, ClassicalRegister
Expand Down Expand Up @@ -251,3 +253,27 @@ def test_positional_label(self):
label_gate = SXGate("I am a little label")
self.assertIsNot(gate, label_gate)
self.assertEqual(label_gate.label, "I am a little label")

def test_immutable_pickle(self):
gate = SXGate()
self.assertFalse(gate.mutable)
with io.BytesIO() as fd:
pickle.dump(gate, fd)
fd.seek(0)
copied = pickle.load(fd)
self.assertFalse(copied.mutable)
self.assertIs(copied, gate)

def test_mutable_pickle(self):
gate = SXGate()
clbit = Clbit()
condition_gate = gate.c_if(clbit, 0)
self.assertIsNot(gate, condition_gate)
self.assertEqual(condition_gate.condition, (clbit, 0))
self.assertTrue(condition_gate.mutable)
with io.BytesIO() as fd:
pickle.dump(condition_gate, fd)
fd.seek(0)
copied = pickle.load(fd)
self.assertEqual(copied, condition_gate)
self.assertTrue(copied.mutable)
14 changes: 14 additions & 0 deletions test/python/compiler/test_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2173,6 +2173,20 @@ def run(self, dag):
added_cal = qc_test.calibrations["sx"][((0,), ())]
self.assertEqual(added_cal, ref_cal)

@data(0, 1, 2, 3)
def test_parallel_singleton_conditional_gate(self, opt_level):
"""Test that singleton mutable instance doesn't lose state in parallel."""
backend = FakeNairobiV2()
circ = QuantumCircuit(2, 1)
circ.h(0)
circ.measure(0, circ.clbits[0])
circ.z(1).c_if(circ.clbits[0], 1)
res = transpile(
[circ, circ], backend, optimization_level=opt_level, seed_transpiler=123456769
)
self.assertTrue(res[0].data[-1].operation.mutable)
self.assertEqual(res[0].data[-1].operation.condition, (res[0].clbits[0], 1))

@data(0, 1, 2, 3)
def test_backendv2_and_basis_gates(self, opt_level):
"""Test transpile() with BackendV2 and basis_gates set."""
Expand Down

0 comments on commit 1f37e23

Please sign in to comment.