diff --git a/docs/source/examples/substrafl/get_started/run_mnist_torch.ipynb b/docs/source/examples/substrafl/get_started/run_mnist_torch.ipynb index b98f90ea1..4e94491e0 100644 --- a/docs/source/examples/substrafl/get_started/run_mnist_torch.ipynb +++ b/docs/source/examples/substrafl/get_started/run_mnist_torch.ipynb @@ -456,6 +456,7 @@ " index_generator=index_generator,\n", " dataset=TorchDataset,\n", " seed=seed,\n", + " use_gpu=False,\n", " )" ] }, @@ -574,6 +575,8 @@ "As a last step before launching our experiment, we need to specify the third parties dependencies required to run it.\n", "The [Dependency](https://docs.substra.org/en/stable/substrafl_doc/api/dependency.html) object is instantiated in order to install the right libraries in\n", "the Python environment of each organization.\n", + "\n", + "The CPU torch version is installed here to have a `Dependency` object as light as possible as we don't use GPUs (`use_gpu` set to `False`). Remove the `--extra-index-url` to install the cuda torch version.\n", "\n" ] }, @@ -587,7 +590,7 @@ "source": [ "from substrafl.dependency import Dependency\n", "\n", - "dependencies = Dependency(pypi_dependencies=[\"numpy==1.24.3\", \"torch==2.0.1\", \"scikit-learn==1.3.1\"])" + "dependencies = Dependency(pypi_dependencies=[\"numpy==1.24.3\", \"scikit-learn==1.3.1\", \"torch==2.0.1\", \"--extra-index-url https://download.pytorch.org/whl/cpu\"])" ] }, { diff --git a/docs/source/examples/substrafl/go_further/run_mnist_cyclic.ipynb b/docs/source/examples/substrafl/go_further/run_mnist_cyclic.ipynb index 9f83ff413..8557b6c29 100644 --- a/docs/source/examples/substrafl/go_further/run_mnist_cyclic.ipynb +++ b/docs/source/examples/substrafl/go_further/run_mnist_cyclic.ipynb @@ -819,6 +819,7 @@ " index_generator=index_generator,\n", " dataset=TorchDataset,\n", " seed=seed,\n", + " use_gpu=False,\n", " )\n", "\n", "\n", @@ -923,7 +924,7 @@ "source": [ "from substrafl.dependency import Dependency\n", "\n", - "dependencies = Dependency(pypi_dependencies=[\"numpy==1.24.3\", \"torch==2.0.1\", \"scikit-learn==1.3.1\"])" + "dependencies = Dependency(pypi_dependencies=[\"numpy==1.24.3\", \"scikit-learn==1.3.1\", \"torch==2.0.1\", \"--extra-index-url https://download.pytorch.org/whl/cpu\"])" ] }, {