Skip to content

Commit

Permalink
Merge pull request #443 from qiboteam/setthreads
Browse files Browse the repository at this point in the history
Set threads warnings
  • Loading branch information
scarrazza committed Jul 9, 2021
2 parents 47ad4ad + 9ad5edf commit c319d97
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 7 deletions.
8 changes: 3 additions & 5 deletions src/qibo/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ def construct_backend(self, name):
new_backend = self.available_backends.get(name)()
if self.active_backend is not None:
new_backend.set_precision(self.active_backend.precision)
if self.active_backend.default_device:
new_backend.set_device(self.active_backend.default_device)
self.constructed_backends[name] = new_backend
return self.constructed_backends.get(name)

Expand Down Expand Up @@ -199,8 +197,9 @@ def set_device(name):
"""
if not config.ALLOW_SWITCHERS and name != K.default_device:
log.warning("Device should not be changed after allocating gates.")
K.set_device(name)
for bk in K.constructed_backends.values():
if bk.name != "numpy":
if bk.name != "numpy" and bk != K.active_backend:
bk.set_device(name)


Expand All @@ -214,8 +213,7 @@ def set_threads(nthreads):
Args:
nthreads (int): number of threads.
"""
for bk in K.constructed_backends.values():
bk.set_threads(nthreads)
K.set_threads(nthreads)


def get_threads():
Expand Down
11 changes: 9 additions & 2 deletions src/qibo/backends/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,19 @@ def __init__(self):
self.newaxis = np.newaxis
self.oom_error = MemoryError
self.optimization = None
self.cpu_devices = ["/CPU:0"]
self.gpu_devices = []
self.default_device = self.cpu_devices[0]

def set_device(self, name):
log.warning("Numpy does not support device placement. "
"Aborting device change.")

def set_threads(self, nthreads):
log.warning("Numpy backend supports only single-thread execution. "
"Cannot change the number of threads.")
abstract.AbstractBackend.set_threads(self, nthreads)

def to_numpy(self, x):
return x

Expand Down Expand Up @@ -404,7 +412,6 @@ def __init__(self):
if "NUMBA_NUM_THREADS" in os.environ: # pragma: no cover
self.set_threads(int(os.environ.get("NUMBA_NUM_THREADS")))

# TODO: reconsider device management
self.cpu_devices = ["/CPU:0"]
self.gpu_devices = [f"/GPU:{i}" for i in range(ngpu)]
if self.gpu_devices: # pragma: no cover
Expand Down Expand Up @@ -446,7 +453,7 @@ def set_device(self, name):
self.set_engine("numba")

def set_threads(self, nthreads):
super().set_threads(nthreads)
abstract.AbstractBackend.set_threads(self, nthreads)
import numba # pylint: disable=E0401
numba.set_num_threads(nthreads)

Expand Down
11 changes: 11 additions & 0 deletions src/qibo/backends/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ def __init__(self):
def set_device(self, name):
abstract.AbstractBackend.set_device(self, name)

def set_threads(self, nthreads):
log.warning("`set_threads` is not supported by the tensorflow "
"backend. Please use tensorflow's thread setters: "
"`tf.config.threading.set_inter_op_parallelism_threads` "
"or `tf.config.threading.set_intra_op_parallelism_threads` "
"to switch the number of threads.")
abstract.AbstractBackend.set_threads(self, nthreads)

def to_numpy(self, x):
if isinstance(x, self.np.ndarray):
return x
Expand Down Expand Up @@ -216,6 +224,9 @@ def __init__(self):
if "OMP_NUM_THREADS" in os.environ: # pragma: no cover
self.set_threads(int(os.environ.get("OMP_NUM_THREADS")))

def set_threads(self, nthreads):
abstract.AbstractBackend.set_threads(self, nthreads)

def initial_state(self, nqubits, is_matrix=False):
return self.op.initial_state(nqubits, self.dtypes('DTYPECPX'),
is_matrix=is_matrix,
Expand Down
10 changes: 10 additions & 0 deletions src/qibo/tests/test_backends_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ def test_set_device(backend, caplog):
bk.set_device(device)


def test_set_threads(backend, caplog):
original_threads = backends.get_threads()
bkname = backends.get_backend()
backends.set_threads(1)
if bkname == "numpy" or bkname == "tensorflow":
assert "WARNING" in caplog.text
assert backends.get_threads() == 1
backends.set_threads(original_threads)


def test_set_shot_batch_size():
import qibo
assert qibo.get_batch_size() == 2 ** 18
Expand Down

0 comments on commit c319d97

Please sign in to comment.