Skip to content

Commit

Permalink
Enabled string assignment for multiple params carrying same name
Browse files Browse the repository at this point in the history
The check is now based on the value type to infer if assignment should be done on Parameters or ParameterVectors

Removed unnecessary import from utils

Corrected string assignment

Correction part 2

Corrected test

The inplace=True argument was preventing the reuse of a parametrized waveform in the schedule, making the test fail
  • Loading branch information
arthurostrauss committed Apr 26, 2024
1 parent 4f81f74 commit 6c53868
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 15 deletions.
28 changes: 20 additions & 8 deletions qiskit/pulse/parameter_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@
from qiskit.pulse.library import SymbolicPulse, Waveform
from qiskit.pulse.schedule import Schedule, ScheduleBlock
from qiskit.pulse.transforms.alignments import AlignmentKind
from qiskit.pulse.utils import format_parameter_value, _validate_parameter_vector
from qiskit.pulse.utils import (
format_parameter_value,
_validate_parameter_vector,
_validate_parameter_value,
)


class NodeVisitor:
Expand Down Expand Up @@ -411,23 +415,31 @@ def _unroll_param_dict(
A dictionary from parameter to value.
"""
out = {}
param_name_dict = {param.name: param for param in self.parameters}
param_name_dict = {param.name: [] for param in self.parameters}
for param in self.parameters:
param_name_dict[param.name].append(param)
param_vec_dict = {
param.vector.name: param.vector
for param in self.parameters
if isinstance(param, ParameterVectorElement)
}
for name in param_vec_dict.keys():
if name in param_name_dict:
param_name_dict[name].append(param_vec_dict[name])
else:
param_name_dict[name] = [param_vec_dict[name]]

for parameter, value in parameter_binds.items():
if isinstance(parameter, ParameterVector):
_validate_parameter_vector(parameter, value)
out.update(zip(parameter, value))
elif isinstance(parameter, str):
if parameter in param_vec_dict:
param = param_vec_dict[parameter]
_validate_parameter_vector(param, value)
out.update(zip(param, value))
elif parameter in param_name_dict:
out[param_name_dict[parameter]] = value
for param in param_name_dict[parameter]:
is_vec = _validate_parameter_value(param, value)
if is_vec:
out.update(zip(param, value))
else:
out[param] = value
else:
out[parameter] = value
return out
21 changes: 20 additions & 1 deletion qiskit/pulse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy as np

from qiskit.circuit import ParameterVector
from qiskit.circuit import ParameterVector, Parameter
from qiskit.circuit.parameterexpression import ParameterExpression
from qiskit.pulse.exceptions import UnassignedDurationError, QiskitError, PulseError

Expand Down Expand Up @@ -132,3 +132,22 @@ def _validate_parameter_vector(parameter: ParameterVector, value):
f"Parameter vector '{parameter.name}' has length {len(parameter)},"
f" but was assigned to {len(value)} values."
)


def _validate_single_parameter(parameter: Parameter, value):
"""Validate single parameter and its value."""
if not isinstance(value, (int, float, complex, ParameterExpression)):
raise PulseError(
f"Parameter '{parameter.name}' is not assignable to {value}."
)


def _validate_parameter_value(parameter, value):
"""Validate parameter and its value."""
if isinstance(parameter, ParameterVector):
_validate_parameter_vector(parameter, value)
return True
else:
_validate_single_parameter(parameter, value)
return False

12 changes: 6 additions & 6 deletions test/python/pulse/test_parameter_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,13 @@ def test_pulse_assignment_with_parameter_names(self):
block += pulse.Play(waveform2, pulse.DriveChannel(10))
block += pulse.ShiftPhase(param_vec[0], pulse.DriveChannel(10))
block += pulse.ShiftPhase(param_vec[1], pulse.DriveChannel(10))
block.assign_parameters({"amp": 0.2, "sigma": 4, "param_vec": [3.14, 1.57]}, inplace=True)
block1 = block.assign_parameters({"amp": 0.2, "sigma": 4, "param_vec": [3.14, 1.57]}, inplace=False)

self.assertEqual(block.blocks[0].pulse.amp, 0.2)
self.assertEqual(block.blocks[0].pulse.sigma, 4.0)
self.assertEqual(block.blocks[1].pulse.amp, 0.2)
self.assertEqual(block.blocks[2].phase, 3.14)
self.assertEqual(block.blocks[3].phase, 1.57)
self.assertEqual(block1.blocks[0].pulse.amp, 0.2)
self.assertEqual(block1.blocks[0].pulse.sigma, 4.0)
self.assertEqual(block1.blocks[1].pulse.amp, 0.2)
self.assertEqual(block1.blocks[2].phase, 3.14)
self.assertEqual(block1.blocks[3].phase, 1.57)

sched = pulse.Schedule()
sched += pulse.Play(waveform, pulse.DriveChannel(10))
Expand Down

0 comments on commit 6c53868

Please sign in to comment.