diff --git a/src/qibo/backends/__init__.py b/src/qibo/backends/__init__.py index 8252705199..dcd05483db 100644 --- a/src/qibo/backends/__init__.py +++ b/src/qibo/backends/__init__.py @@ -105,8 +105,6 @@ def construct_backend(self, name): new_backend = self.available_backends.get(name)() if self.active_backend is not None: new_backend.set_precision(self.active_backend.precision) - if self.active_backend.default_device: - new_backend.set_device(self.active_backend.default_device) self.constructed_backends[name] = new_backend return self.constructed_backends.get(name) @@ -199,8 +197,9 @@ def set_device(name): """ if not config.ALLOW_SWITCHERS and name != K.default_device: log.warning("Device should not be changed after allocating gates.") + K.set_device(name) for bk in K.constructed_backends.values(): - if bk.name != "numpy": + if bk.name != "numpy" and bk != K.active_backend: bk.set_device(name) @@ -214,8 +213,7 @@ def set_threads(nthreads): Args: nthreads (int): number of threads. """ - for bk in K.constructed_backends.values(): - bk.set_threads(nthreads) + K.set_threads(nthreads) def get_threads(): diff --git a/src/qibo/backends/numpy.py b/src/qibo/backends/numpy.py index 8394e69081..24f707ff53 100644 --- a/src/qibo/backends/numpy.py +++ b/src/qibo/backends/numpy.py @@ -25,11 +25,19 @@ def __init__(self): self.newaxis = np.newaxis self.oom_error = MemoryError self.optimization = None + self.cpu_devices = ["/CPU:0"] + self.gpu_devices = [] + self.default_device = self.cpu_devices[0] def set_device(self, name): log.warning("Numpy does not support device placement. " "Aborting device change.") + def set_threads(self, nthreads): + log.warning("Numpy backend supports only single-thread execution. " + "Cannot change the number of threads.") + abstract.AbstractBackend.set_threads(self, nthreads) + def to_numpy(self, x): return x @@ -404,7 +412,6 @@ def __init__(self): if "NUMBA_NUM_THREADS" in os.environ: # pragma: no cover self.set_threads(int(os.environ.get("NUMBA_NUM_THREADS"))) - # TODO: reconsider device management self.cpu_devices = ["/CPU:0"] self.gpu_devices = [f"/GPU:{i}" for i in range(ngpu)] if self.gpu_devices: # pragma: no cover @@ -446,7 +453,7 @@ def set_device(self, name): self.set_engine("numba") def set_threads(self, nthreads): - super().set_threads(nthreads) + abstract.AbstractBackend.set_threads(self, nthreads) import numba # pylint: disable=E0401 numba.set_num_threads(nthreads) diff --git a/src/qibo/backends/tensorflow.py b/src/qibo/backends/tensorflow.py index 5f93d53db5..f74062800f 100644 --- a/src/qibo/backends/tensorflow.py +++ b/src/qibo/backends/tensorflow.py @@ -49,6 +49,14 @@ def __init__(self): def set_device(self, name): abstract.AbstractBackend.set_device(self, name) + def set_threads(self, nthreads): + log.warning("`set_threads` is not supported by the tensorflow " + "backend. Please use tensorflow's thread setters: " + "`tf.config.threading.set_inter_op_parallelism_threads` " + "or `tf.config.threading.set_intra_op_parallelism_threads` " + "to switch the number of threads.") + abstract.AbstractBackend.set_threads(self, nthreads) + def to_numpy(self, x): if isinstance(x, self.np.ndarray): return x @@ -216,6 +224,9 @@ def __init__(self): if "OMP_NUM_THREADS" in os.environ: # pragma: no cover self.set_threads(int(os.environ.get("OMP_NUM_THREADS"))) + def set_threads(self, nthreads): + abstract.AbstractBackend.set_threads(self, nthreads) + def initial_state(self, nqubits, is_matrix=False): return self.op.initial_state(nqubits, self.dtypes('DTYPECPX'), is_matrix=is_matrix, diff --git a/src/qibo/tests/test_backends_init.py b/src/qibo/tests/test_backends_init.py index 1aa14a9acf..c2f912b51f 100644 --- a/src/qibo/tests/test_backends_init.py +++ b/src/qibo/tests/test_backends_init.py @@ -103,6 +103,16 @@ def test_set_device(backend, caplog): bk.set_device(device) +def test_set_threads(backend, caplog): + original_threads = backends.get_threads() + bkname = backends.get_backend() + backends.set_threads(1) + if bkname == "numpy" or bkname == "tensorflow": + assert "WARNING" in caplog.text + assert backends.get_threads() == 1 + backends.set_threads(original_threads) + + def test_set_shot_batch_size(): import qibo assert qibo.get_batch_size() == 2 ** 18