diff --git a/changes/436.changed b/changes/436.changed new file mode 100644 index 00000000..25a6354f --- /dev/null +++ b/changes/436.changed @@ -0,0 +1 @@ +Apply `use_gpu` to `disable_gpu` renaming on SubstraFL in all `TorchAlgo` \ No newline at end of file diff --git a/docs/source/examples/substrafl/get_started/run_mnist_torch.ipynb b/docs/source/examples/substrafl/get_started/run_mnist_torch.ipynb index 4fca9ef4..de89b993 100644 --- a/docs/source/examples/substrafl/get_started/run_mnist_torch.ipynb +++ b/docs/source/examples/substrafl/get_started/run_mnist_torch.ipynb @@ -439,7 +439,7 @@ " index_generator=index_generator,\n", " dataset=TorchDataset,\n", " seed=seed,\n", - " use_gpu=False,\n", + " disable_gpu=True,\n", " )" ] }, @@ -558,7 +558,7 @@ "The [Dependency](https://docs.substra.org/en/stable/substrafl_doc/api/dependency.html) object is instantiated in order to install the right libraries in\n", "the Python environment of each organization.\n", "\n", - "The CPU torch version is installed here to have a `Dependency` object as light as possible as we don't use GPUs (`use_gpu` set to `False`). Remove the `--extra-index-url` to install the cuda torch version." + "The CPU torch version is installed here to have a `Dependency` object as light as possible as we don't use GPUs (`disable_gpu` set to `True`). Remove the `--extra-index-url` to install the cuda torch version." 
] }, { diff --git a/docs/source/examples/substrafl/go_further/run_mnist_cyclic.ipynb b/docs/source/examples/substrafl/go_further/run_mnist_cyclic.ipynb index 078ad88b..f367fd68 100644 --- a/docs/source/examples/substrafl/go_further/run_mnist_cyclic.ipynb +++ b/docs/source/examples/substrafl/go_further/run_mnist_cyclic.ipynb @@ -1,1109 +1,1109 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Creating Torch Cyclic strategy on MNIST dataset\n", - "\n", - "This example illustrates an advanced usage of SubstraFL and proposes to implement a new Federated Learning strategy,\n", - "called **Cyclic Strategy**, using the SubstraFL base classes.\n", - "This example runs on the [MNIST Dataset of handwritten digits](http://yann.lecun.com/exdb/mnist/) using PyTorch.\n", - "In this example, we work on 28x28 pixel sized grayscale images. This is a classification problem\n", - "aiming to recognize the number written on each image.\n", - "\n", - "The **Cyclic Strategy** consists in training locally a model on different organizations (or centers) sequentially (one after the other). We\n", - "consider a round of this strategy to be a full cycle of local trainings.\n", - "\n", - "This example shows an implementation of the CyclicTorchAlgo using\n", - "[TorchAlgo](https://docs.substra.org/en/stable/substrafl_doc/api/algorithms.html#torch-algorithms) as base class, and the CyclicStrategy implementation using\n", - "[Strategy](https://docs.substra.org/en/stable/substrafl_doc/api/strategies.html) as base class.\n", - "\n", - "This example does not use a deployed platform of Substra and runs in local mode.\n", - "\n", - "To run this example, you need to download and unzip the assets needed to run it in the same directory as used this example:\n", - "\n", - "- [assets required to run this example](../../../tmp/torch_cyclic_assets.zip)\n", - "\n", - "Please ensure to have all the libraries installed. 
A *requirements.txt* file is included in the zip file, where you can run the command `pip install -r requirements.txt` to install them.\n", - "\n", - "**Substra** and **SubstraFL** should already be installed. If not, follow the instructions described [here](https://docs.substra.org/en/stable/substrafl_doc/substrafl_overview.html#installation).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup\n", - "\n", - "This example runs with three organizations. Two organizations provide datasets, while a third\n", - "one provides the algorithm.\n", - "\n", - "In the following code cell, we define the different organizations needed for our FL experiment.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "from substra import Client\n", - "\n", - "N_CLIENTS = 3\n", - "\n", - "client_0 = Client(client_name=\"org-1\")\n", - "client_1 = Client(client_name=\"org-2\")\n", - "client_2 = Client(client_name=\"org-3\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Every computation will run in `subprocess` mode, where everything runs locally in Python\n", - "subprocesses.\n", - "Other backend_types are:\n", - "\n", - "- `docker` mode where computations run locally in docker containers\n", - "- `remote` mode where computations run remotely (you need to have a deployed platform for that)\n", - "\n", - "To run in remote mode, use the following syntax:\n", - "\n", - "`client_remote = Client(backend_type=\"remote\", url=\"MY_BACKEND_URL\", username=\"my-username\", password=\"my-password\")`\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# Create a dictionary to easily access each client from its human-friendly id\n", - "clients = {\n", - " client_0.organization_info().organization_id: client_0,\n", - " 
client_1.organization_info().organization_id: client_1,\n", - " client_2.organization_info().organization_id: client_2,\n", - "}\n", - "\n", - "# Store organization IDs\n", - "ORGS_ID = list(clients)\n", - "# Algo provider is defined as the first organization.\n", - "ALGO_ORG_ID = ORGS_ID[0]\n", - "# All organizations provide data in this cyclic setup.\n", - "DATA_PROVIDER_ORGS_ID = ORGS_ID" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Data and metrics\n", - "\n", - "### Data preparation\n", - "\n", - "This section downloads (if needed) the **MNIST dataset** using the [torchvision library](https://pytorch.org/vision/stable/index.html).\n", - "It extracts the images from the raw files and locally creates a folder for each\n", - "organization.\n", - "\n", - "Each organization will have access to half the training data and half the test data (which\n", - "corresponds to **30,000**\n", - "images for training and **5,000** for testing each).\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "import pathlib\n", - "from torch_cyclic_assets.dataset.cyclic_mnist_dataset import setup_mnist\n", - "\n", - "\n", - "# Create the temporary directory for generated data\n", - "(pathlib.Path.cwd() / \"tmp\").mkdir(exist_ok=True)\n", - "data_path = pathlib.Path.cwd() / \"tmp\" / \"data_mnist\"\n", - "\n", - "setup_mnist(data_path, len(DATA_PROVIDER_ORGS_ID))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dataset registration\n", - "\n", - "A [Dataset](https://docs.substra.org/en/stable/documentation/concepts.html#dataset) is composed of an **opener**, which is a Python script that can load\n", - "the data from the files in memory and a description markdown file.\n", - "The [Dataset](https://docs.substra.org/en/stable/documentation/concepts.html#dataset) object itself does not contain the data. 
The proper asset that contains the\n", - "data is the **datasample asset**.\n", - "\n", - "A **datasample** contains a local path to the data. A datasample can be linked to a dataset in order to add data to a\n", - "dataset.\n", - "\n", - "Data privacy is a key concept for Federated Learning experiments. That is why we set\n", - "[Permissions](https://docs.substra.org/en/stable/documentation/concepts.html#permissions) for [Assets](https://docs.substra.org/en/stable/documentation/concepts.html#assets) to determine how each organization\n", - "can access a specific asset.\n", - "You can read more about these concepts in the [User Guide](https://docs.substra.org/en/stable/documentation/concepts.htm).\n", - "\n", - "Note that metadata such as the assets' creation date and the asset owner are visible to all the organizations of a\n", - "network.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "from substra.sdk.schemas import DatasetSpec\n", - "from substra.sdk.schemas import Permissions\n", - "from substra.sdk.schemas import DataSampleSpec\n", - "\n", - "assets_directory = pathlib.Path.cwd() / \"torch_cyclic_assets\"\n", - "dataset_keys = {}\n", - "train_datasample_keys = {}\n", - "test_datasample_keys = {}\n", - "\n", - "for i, org_id in enumerate(DATA_PROVIDER_ORGS_ID):\n", - " client = clients[org_id]\n", - "\n", - " permissions_dataset = Permissions(public=False, authorized_ids=[ALGO_ORG_ID])\n", - "\n", - " # DatasetSpec is the specification of a dataset. 
It makes sure every field\n", - " # is well-defined, and that our dataset is ready to be registered.\n", - " # The real dataset object is created in the add_dataset method.\n", - "\n", - " dataset = DatasetSpec(\n", - " name=\"MNIST\",\n", - " data_opener=assets_directory / \"dataset\" / \"cyclic_mnist_opener.py\",\n", - " description=assets_directory / \"dataset\" / \"description.md\",\n", - " permissions=permissions_dataset,\n", - " logs_permission=permissions_dataset,\n", - " )\n", - " dataset_keys[org_id] = client.add_dataset(dataset)\n", - " assert dataset_keys[org_id], \"Missing dataset key\"\n", - "\n", - " # Add the training data on each organization.\n", - " data_sample = DataSampleSpec(\n", - " data_manager_keys=[dataset_keys[org_id]],\n", - " path=data_path / f\"org_{i+1}\" / \"train\",\n", - " )\n", - " train_datasample_keys[org_id] = client.add_data_sample(data_sample)\n", - "\n", - " # Add the testing data on each organization.\n", - " data_sample = DataSampleSpec(\n", - " data_manager_keys=[dataset_keys[org_id]],\n", - " path=data_path / f\"org_{i+1}\" / \"test\",\n", - " )\n", - " test_datasample_keys[org_id] = client.add_data_sample(data_sample)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Metrics definition\n", - "\n", - "A metric is a function used to evaluate the performance of your model.\n", - "\n", - "To add a metric, you need to define a function that computes and returns a performance\n", - "from the data (as returned by the opener) and the predictions of the model.\n", - "\n", - "When using a Torch SubstraFL algorithm, the predictions are returned by the `predict` function.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "from sklearn.metrics import accuracy_score\n", - "from sklearn.metrics import roc_auc_score\n", - "import numpy as np\n", - "\n", - "\n", - "def accuracy(data_from_opener, 
predictions):\n", - " y_true = data_from_opener[\"labels\"]\n", - "\n", - " return accuracy_score(y_true, np.argmax(predictions, axis=1))\n", - "\n", - "\n", - "def roc_auc(data_from_opener, predictions):\n", - " y_true = data_from_opener[\"labels\"]\n", - "\n", - " n_class = np.max(y_true) + 1\n", - " y_true_one_hot = np.eye(n_class)[y_true]\n", - "\n", - " return roc_auc_score(y_true_one_hot, predictions)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Machine learning components definition\n", - "\n", - "This section uses the PyTorch based SubstraFL API to simplify the definition of machine learning components.\n", - "However, SubstraFL is compatible with any machine learning framework.\n", - "\n", - "\n", - "In this section, you will:\n", - "\n", - "- Register a model and its dependencies\n", - "- Create a federated learning strategy\n", - "- Specify the training and aggregation nodes\n", - "- Specify the test nodes\n", - "- Actually run the computations\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Model definition\n", - "\n", - "We choose to use a classic torch CNN as the model to train. 
The model architecture is defined by the user\n", - "independently of SubstraFL.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "import torch\n", - "from torch import nn\n", - "import torch.nn.functional as F\n", - "\n", - "seed = 42\n", - "torch.manual_seed(seed)\n", - "\n", - "\n", - "class CNN(nn.Module):\n", - " def __init__(self):\n", - " super(CNN, self).__init__()\n", - " self.conv1 = nn.Conv2d(1, 32, kernel_size=5)\n", - " self.conv2 = nn.Conv2d(32, 32, kernel_size=5)\n", - " self.conv3 = nn.Conv2d(32, 64, kernel_size=5)\n", - " self.fc1 = nn.Linear(3 * 3 * 64, 256)\n", - " self.fc2 = nn.Linear(256, 10)\n", - "\n", - " def forward(self, x, eval=False):\n", - " x = F.relu(self.conv1(x))\n", - " x = F.relu(F.max_pool2d(self.conv2(x), 2))\n", - " x = F.dropout(x, p=0.5, training=not eval)\n", - " x = F.relu(F.max_pool2d(self.conv3(x), 2))\n", - " x = F.dropout(x, p=0.5, training=not eval)\n", - " x = x.view(-1, 3 * 3 * 64)\n", - " x = F.relu(self.fc1(x))\n", - " x = F.dropout(x, p=0.5, training=not eval)\n", - " x = self.fc2(x)\n", - " return F.log_softmax(x, dim=1)\n", - "\n", - "\n", - "model = CNN()\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", - "criterion = torch.nn.CrossEntropyLoss()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Specifying on how much data to train\n", - "\n", - "To specify on how much data to train at each round, we use the `index_generator` object.\n", - "We specify the batch size and the number of batches (named `num_updates`) to consider for each round.\n", - "See [Index Generator](https://docs.substra.org/en/stable/substrafl_doc/substrafl_overview.html#index-generator) for more details.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "from substrafl.index_generator import 
NpIndexGenerator\n", - "\n", - "# Number of model updates between each FL strategy aggregation.\n", - "NUM_UPDATES = 100\n", - "\n", - "# Number of samples per update.\n", - "BATCH_SIZE = 32\n", - "\n", - "index_generator = NpIndexGenerator(\n", - " batch_size=BATCH_SIZE,\n", - " num_updates=NUM_UPDATES,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Torch Dataset definition\n", - "\n", - "This torch Dataset is used to preprocess the data using the `__getitem__` function.\n", - "\n", - "This torch Dataset needs to have a specific `__init__` signature, that must contain (self, data_from_opener, is_inference).\n", - "\n", - "The `__getitem__` function is expected to return (inputs, outputs) if `is_inference` is `False`, else only the inputs.\n", - "This behavior can be changed by re-writing the `_local_train` or `predict` methods.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "class TorchDataset(torch.utils.data.Dataset):\n", - " def __init__(self, data_from_opener, is_inference: bool):\n", - " self.x = data_from_opener[\"images\"]\n", - " self.y = data_from_opener[\"labels\"]\n", - " self.is_inference = is_inference\n", - "\n", - " def __getitem__(self, idx):\n", - " if self.is_inference:\n", - " x = torch.FloatTensor(self.x[idx][None, ...]) / 255\n", - " return x\n", - "\n", - " else:\n", - " x = torch.FloatTensor(self.x[idx][None, ...]) / 255\n", - "\n", - " y = torch.tensor(self.y[idx]).type(torch.int64)\n", - " y = F.one_hot(y, 10)\n", - " y = y.type(torch.float32)\n", - "\n", - " return x, y\n", - "\n", - " def __len__(self):\n", - " return len(self.x)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Cyclic Strategy implementation\n", - "\n", - "A FL strategy specifies how to train a model on distributed data.\n", - "\n", - "The **Cyclic Strategy** passes the model from an organization to 
the next one, until all\n", - "the data available in Substra has been sequentially presented to the model.\n", - "\n", - "This is not the most efficient strategy. The model will overfit the last dataset it sees,\n", - "and the order of training will impact the performances of the model. But we will use this implementation\n", - "as an example to explain and show how to implement your own strategies using SubstraFL.\n", - "\n", - "To instantiate this new strategy, we need to overwrite three methods:\n", - "\n", - "- `initialization_round`, to indicate what tasks to execute at round 0, in order to setup the variable\n", - " and be able to compute the performances of the model before any training.\n", - "- `perform_round`, to indicate what tasks and in which order we need to compute to execute a round of the strategy.\n", - "- `perform_evaluation`, to indicate how to compute the predictions and performances .\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "from typing import Any\n", - "from typing import List\n", - "from typing import Optional\n", - "from typing import Dict\n", - "from typing import Callable\n", - "\n", - "from substrafl import strategies\n", - "from substrafl.algorithms.algo import Algo\n", - "from substrafl.nodes.aggregation_node import AggregationNode\n", - "from substrafl.nodes.test_data_node import TestDataNode\n", - "from substrafl.nodes.train_data_node import TrainDataNode\n", - "\n", - "\n", - "class CyclicStrategy(strategies.Strategy):\n", - " \"\"\"The base class Strategy proposes a default compute plan structure\n", - " in its ``build_compute_plan``method implementation, dedicated to Federated Learning compute plan.\n", - " This method calls ``initialization_round`` at round 0, and then repeats ``perform_round`` for ``num_rounds``.\n", - "\n", - " The default ``build_compute_plan`` implementation also takes into account the given 
evaluation\n", - " strategy to trigger the tests tasks when needed.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " algo: Algo,\n", - " metric_functions: Optional[Dict[str, Callable]] = None,\n", - " *args,\n", - " **kwargs,\n", - " ):\n", - " \"\"\"\n", - " It is possible to add any arguments to a Strategy. It is important to pass these arguments as\n", - " args or kwargs to the parent class, using the super().__init__(...) method.\n", - " Indeed, SubstraFL does not use the instance of the object. It re-instantiates them at each new task\n", - " using the args and kwargs passed to the parent class, and uses the save and load local state method to retrieve\n", - " its state.\n", - "\n", - " Args:\n", - " algo (Algo): A Strategy takes an Algo as argument, in order to deal with framework\n", - " specific function in a dedicated object.\n", - " metric_functions (Optional[Dict[str, Callable]]):\n", - " list of Functions that implement the different metrics. If a Dict is given, the keys will be used to\n", - " register the result of the associated function. If a Function or a List is given, function.__name__\n", - " will be used to store the result.\n", - " \"\"\"\n", - " super().__init__(algo=algo, metric_functions=metric_functions, *args, **kwargs)\n", - "\n", - " self._cyclic_local_state = None\n", - " self._cyclic_shared_state = None\n", - "\n", - " @property\n", - " def name(self) -> str:\n", - " \"\"\"The name of the strategy. 
Useful to indicate which Algo\n", - " are compatible or aren't with this strategy.\n", - "\n", - " Returns:\n", - " str: Name of the strategy\n", - " \"\"\"\n", - " return \"Cyclic Strategy\"\n", - "\n", - " def initialization_round(\n", - " self,\n", - " *,\n", - " train_data_nodes: List[TrainDataNode],\n", - " clean_models: bool,\n", - " round_idx: Optional[int] = 0,\n", - " additional_orgs_permissions: Optional[set] = None,\n", - " ):\n", - " \"\"\"The ``initialization_round`` function is called at round 0 on the\n", - " ``build_compute_plan`` method. In our strategy, we want to initialize\n", - " ``_cyclic_local_state`` in order to be able to test the model before\n", - " any training.\n", - "\n", - " We only initialize the model on the first train data node.\n", - "\n", - " Args:\n", - " train_data_nodes (List[TrainDataNode]): Train data nodes representing the different\n", - " organizations containing data we want to train on.\n", - " clean_models (bool): Boolean to indicate if we want to keep intermediate shared states.\n", - " Only taken into account in ``remote`` mode.\n", - " round_idx (Optional[int], optional): Current round index. The initialization round is zero by default,\n", - " but you are free to change it in the ``build_compute_plan`` method. Defaults to 0.\n", - " additional_orgs_permissions (Optional[set], optional): additional organization ids that could\n", - " have access to the outputs the task. 
In our case, this corresponds to the organization\n", - " containing test data nodes, in order to provide access to the model and to allow to\n", - " use it on the test data.\n", - " \"\"\"\n", - " first_train_data_node = train_data_nodes[0]\n", - "\n", - " # The algo.initialize method is an empty method useful to load all python object to the platform.\n", - " self._cyclic_local_state = first_train_data_node.init_states(\n", - " operation=self.algo.initialize(\n", - " _algo_name=f\"Initializing with {self.algo.__class__.__name__}\",\n", - " ),\n", - " round_idx=round_idx,\n", - " authorized_ids=set([first_train_data_node.organization_id]) | additional_orgs_permissions,\n", - " clean_models=clean_models,\n", - " )\n", - "\n", - " def perform_round(\n", - " self,\n", - " *,\n", - " train_data_nodes: List[TrainDataNode],\n", - " aggregation_node: Optional[AggregationNode],\n", - " round_idx: int,\n", - " clean_models: bool,\n", - " additional_orgs_permissions: Optional[set] = None,\n", - " ):\n", - " \"\"\"This method is called at each round to perform a series of task. For the cyclic\n", - " strategy we want to design, a round is a full cycle over the different train data\n", - " nodes.\n", - " We link the output of a computed task directly to the next one.\n", - "\n", - " Args:\n", - " train_data_nodes (List[TrainDataNode]): Train data nodes representing the different\n", - " organizations containing data we want to train on.\n", - " aggregation_node (List[AggregationNode]): In the case of the Cyclic Strategy, there is no\n", - " aggregation tasks so no need for AggregationNode.\n", - " clean_models (bool): Boolean to indicate if we want to keep intermediate shared states.\n", - " Only taken into account in ``remote`` mode.\n", - " round_idx (Optional[int], optional): Current round index.\n", - " additional_orgs_permissions (Optional[set], optional): additional organization ids that could\n", - " have access to the outputs the task. 
In our case, this will correspond to the organization\n", - " containing test data nodes, in order to provide access to the model and to allow to\n", - " use it on the test data.\n", - " \"\"\"\n", - " for i, node in enumerate(train_data_nodes):\n", - " # We get the next train_data_node in order to add the organization of the node\n", - " # to the authorized_ids\n", - " next_train_data_node = train_data_nodes[(i + 1) % len(train_data_nodes)]\n", - "\n", - " self._cyclic_local_state, self._cyclic_shared_state = node.update_states(\n", - " operation=self.algo.train(\n", - " node.data_sample_keys,\n", - " shared_state=self._cyclic_shared_state,\n", - " _algo_name=f\"Training with {self.algo.__class__.__name__}\",\n", - " ),\n", - " local_state=self._cyclic_local_state,\n", - " round_idx=round_idx,\n", - " authorized_ids=set([next_train_data_node.organization_id]) | additional_orgs_permissions,\n", - " aggregation_id=None,\n", - " clean_models=clean_models,\n", - " )\n", - "\n", - " def perform_evaluation(\n", - " self,\n", - " test_data_nodes: List[TestDataNode],\n", - " train_data_nodes: List[TrainDataNode],\n", - " round_idx: int,\n", - " ):\n", - " \"\"\"This method is called regarding the given evaluation strategy. 
If the round is included\n", - " in the evaluation strategy, the ``perform_evaluation`` method will be called on the different concerned nodes.\n", - "\n", - " We are using the last computed ``_cyclic_local_state`` to feed the test task, which mean that we will\n", - " always test the model after its training on the last train data nodes of the list.\n", - "\n", - " Args:\n", - " test_data_nodes (List[TestDataNode]): List of all the register test data nodes containing data\n", - " we want to test on.\n", - " train_data_nodes (List[TrainDataNode]): List of all the register train data nodes.\n", - " round_idx (int): Current round index.\n", - " \"\"\"\n", - " for test_node in test_data_nodes:\n", - " test_node.update_states(\n", - " traintask_id=self._cyclic_local_state.key,\n", - " operation=self.evaluate(\n", - " data_samples=test_node.data_sample_keys,\n", - " _algo_name=f\"Evaluating with {self.__class__.__name__}\",\n", - " ),\n", - " round_idx=round_idx,\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Torch Cyclic Algo implementation\n", - "\n", - "A SubstraFL Algo gathers all the defined elements that run locally in each organization.\n", - "This is the only SubstraFL object that is framework specific (here PyTorch specific).\n", - "\n", - "In the case of our **Cyclic Strategy**, we need to use the TorchAlgo base class, and\n", - "overwrite the `strategies` property and the `train` method to ensure that we output\n", - "the shared state we need for our Federated Learning compute plan.\n", - "\n", - "For the **Cyclic Strategy**, the **shared state** will be directly the **model parameters**. 
We will\n", - "retrieve the model from the shared state we receive and send the new parameters updated after\n", - "the local training.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "from substrafl.algorithms.pytorch.torch_base_algo import TorchAlgo\n", - "from substrafl.remote import remote_data\n", - "from substrafl.algorithms.pytorch import weight_manager\n", - "\n", - "\n", - "class TorchCyclicAlgo(TorchAlgo):\n", - " \"\"\"We create here the base class to be inherited for SubstraFL algorithms.\n", - " An Algo is a SubstraFL object that contains all framework specific functions.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " model: torch.nn.Module,\n", - " criterion: torch.nn.modules.loss._Loss,\n", - " optimizer: torch.optim.Optimizer,\n", - " index_generator: NpIndexGenerator,\n", - " dataset: torch.utils.data.Dataset,\n", - " seed: Optional[int] = None,\n", - " use_gpu: bool = True,\n", - " *args,\n", - " **kwargs,\n", - " ):\n", - " \"\"\"It is possible to add any arguments to an Algo. It is important to pass these arguments as\n", - " args or kwargs to the parent class, using the super().__init__(...) method.\n", - " Indeed, SubstraFL does not use the instance of the object. 
It re-instantiates them at each new task\n", - " using the args and kwargs passed to the parent class, and the save and load local state method to retrieve the\n", - " right state.\n", - "\n", - " Args:\n", - " model (torch.nn.modules.module.Module): A torch model.\n", - " criterion (torch.nn.modules.loss._Loss): A torch criterion (loss).\n", - " optimizer (torch.optim.Optimizer): A torch optimizer linked to the model.\n", - " index_generator (BaseIndexGenerator): a stateful index generator.\n", - " dataset (torch.utils.data.Dataset): an instantiable dataset class whose ``__init__`` arguments are\n", - " ``x``, ``y`` and ``is_inference``.\n", - " seed (typing.Optional[int]): Seed set at the algo initialization on each organization. Defaults to None.\n", - " use_gpu (bool): Whether to use the GPUs if they are available. Defaults to True.\n", - " \"\"\"\n", - " super().__init__(\n", - " model=model,\n", - " criterion=criterion,\n", - " optimizer=optimizer,\n", - " index_generator=index_generator,\n", - " dataset=dataset,\n", - " scheduler=None,\n", - " seed=seed,\n", - " use_gpu=use_gpu,\n", - " *args,\n", - " **kwargs,\n", - " )\n", - "\n", - " @property\n", - " def strategies(self) -> List[str]:\n", - " \"\"\"List of compatible strategies.\n", - "\n", - " Returns:\n", - " List[str]: list of compatible strategy name.\n", - " \"\"\"\n", - " return [\"Cyclic Strategy\"]\n", - "\n", - " @remote_data\n", - " def train(\n", - " self,\n", - " data_from_opener: Any,\n", - " shared_state: Optional[dict] = None,\n", - " ) -> dict:\n", - " \"\"\"This method decorated with ``@remote_data`` is a method that is executed inside\n", - " the train tasks of our strategy.\n", - " The decorator is used to retrieve the entire Algo object inside the task, to be able to access all values\n", - " useful for the training (such as the model, the optimizer, etc...).\n", - " The objective is to realize the local training on given data samples, and send the right shared state\n", - " to the 
next task.\n", - "\n", - " Args:\n", - " data_from_opener (Any): data_from_opener are the output of the ``get_data`` method of an opener. This opener\n", - " access the data of a train data nodes, and transforms them to feed methods decorated with\n", - " ``@remote_data``.\n", - " shared_state (Optional[dict], optional): a shared state is a dictionary containing the necessary values\n", - " to use from the previous trainings of the compute plan and initialize the model with it. In our case,\n", - " the shared state is the model parameters obtained after the local train on the previous organization.\n", - " The shared state is equal to None it is the first training of the compute plan.\n", - "\n", - " Returns:\n", - " dict: returns a dict corresponding to the shared state that will be used by the next train function on\n", - " a different organization.\n", - " \"\"\"\n", - " # Create torch dataset\n", - " train_dataset = self._dataset(data_from_opener, is_inference=False)\n", - "\n", - " if self._index_generator.n_samples is None:\n", - " # We need to initiate the index generator number of sample the first time we have access to\n", - " # the information.\n", - " self._index_generator.n_samples = len(train_dataset)\n", - "\n", - " # If the shared state is None, it means that this is the first training of the compute plan,\n", - " # and that we don't have a shared state to take into account yet.\n", - " if shared_state is not None:\n", - " assert self._index_generator.n_samples is not None\n", - " # The shared state is the average of the model parameters for all organizations. 
We set\n", - " # the model to these updated values.\n", - " model_parameters = [torch.from_numpy(x).to(self._device) for x in shared_state[\"model_parameters\"]]\n", - " weight_manager.set_parameters(\n", - " model=self._model,\n", - " parameters=model_parameters,\n", - " with_batch_norm_parameters=False,\n", - " )\n", - "\n", - " # We set the counter of updates to zero.\n", - " self._index_generator.reset_counter()\n", - "\n", - " # Train mode for torch model.\n", - " self._model.train()\n", - "\n", - " # Train the model.\n", - " self._local_train(train_dataset)\n", - "\n", - " # We verify that we trained the model on the right amount of updates.\n", - " self._index_generator.check_num_updates()\n", - "\n", - " # Eval mode for torch model.\n", - " self._model.eval()\n", - "\n", - " # We get the new model parameters values in order to send them in the shared states.\n", - " model_parameters = weight_manager.get_parameters(model=self._model, with_batch_norm_parameters=False)\n", - " new_shared_state = {\"model_parameters\": [p.cpu().detach().numpy() for p in model_parameters]}\n", - "\n", - " return new_shared_state" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To instantiate your algo, you need to instantiate it in a class with no argument. This comment is only valid when you\n", - "inherit from the TorchAlgo base class.\n", - "\n", - "The `TorchDataset` is passed **as a class** to the [TorchAlgo](https://docs.substra.org/en/stable/substrafl_doc/api/algorithms.html#torch-algorithms).\n", - "Indeed, this `TorchDataset` will be instantiated directly on the data provider organization.\n", - "\n", - "> **⚠ WARNING** \n", - "> It is possible to add any arguments to an Algo or a Strategy. It is important to pass these arguments as\n", - "> args or kwargs to the parent class, using the `super().__init__(...)` method.\n", - ">\n", - "> Indeed, SubstraFL does not use the instance of the object. 
It **re-instantiates** them at each new task\n", - "> using the args and kwargs passed to the parent class, and the save and load local state method to retrieve the\n", - "> right state.\n", - "\n", - "To summarize the `Algo` is the place to put all framework specific code we want to apply in tasks. It is often\n", - "the tasks that needs the data to be executed, and that are decorated with `@remote_data`.\n", - "\n", - "The `Strategy` contains the non-framework specific code, such as the `build_compute_plan` method, that creates the\n", - "graph of tasks, the **initialization round**, **perform round** and **perform predict** methods that links tasks to\n", - "each other and links the functions to the nodes.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "class MyAlgo(TorchCyclicAlgo):\n", - " def __init__(self):\n", - " super().__init__(\n", - " model=model,\n", - " criterion=criterion,\n", - " optimizer=optimizer,\n", - " index_generator=index_generator,\n", - " dataset=TorchDataset,\n", - " seed=seed,\n", - " use_gpu=False,\n", - " )\n", - "\n", - "\n", - "strategy = CyclicStrategy(algo=MyAlgo(), metric_functions={\"Accuracy\": accuracy, \"ROC AUC\": roc_auc})" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Where to train where to aggregate\n", - "\n", - "We specify on which data we want to train our model, using the [TrainDataNode](https://docs.substra.org/en/stable/substrafl_doc/api/nodes.html#traindatanode) object.\n", - "Here we train on the two datasets that we have registered earlier.\n", - "\n", - "The [AggregationNode](https://docs.substra.org/en/stable/substrafl_doc/api/nodes.html#aggregationnode) specifies the organization on which the aggregation operation\n", - "will be computed.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - 
"source": [ - "from substrafl.nodes import TrainDataNode\n", - "\n", - "# Create the Train Data Nodes (or training tasks) and save them in a list\n", - "train_data_nodes = [\n", - " TrainDataNode(\n", - " organization_id=org_id,\n", - " data_manager_key=dataset_keys[org_id],\n", - " data_sample_keys=[train_datasample_keys[org_id]],\n", - " )\n", - " for org_id in DATA_PROVIDER_ORGS_ID\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Where and when to test\n", - "\n", - "With the same logic as the train nodes, we create [TestDataNode](https://docs.substra.org/en/stable/substrafl_doc/api/nodes.html#testdatanode) to specify on which\n", - "data we want to test our model.\n", - "\n", - "The [Evaluation Strategy](https://docs.substra.org/en/stable/substrafl_doc/api/evaluation_strategy.html) defines where and at which frequency we\n", - "evaluate the model, using the given metric(s) that you registered in a previous section.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "from substrafl.nodes import TestDataNode\n", - "from substrafl.evaluation_strategy import EvaluationStrategy\n", - "\n", - "# Create the Test Data Nodes (or testing tasks) and save them in a list\n", - "test_data_nodes = [\n", - " TestDataNode(\n", - " organization_id=org_id,\n", - " data_manager_key=dataset_keys[org_id],\n", - " data_sample_keys=[test_datasample_keys[org_id]],\n", - " )\n", - " for org_id in DATA_PROVIDER_ORGS_ID\n", - "]\n", - "\n", - "\n", - "# Test at the end of every round\n", - "my_eval_strategy = EvaluationStrategy(test_data_nodes=test_data_nodes, eval_frequency=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Running the experiment\n", - "\n", - "As a last step before launching our experiment, we need to specify the third parties dependencies required to run it.\n", - "The 
[Dependency](https://docs.substra.org/en/stable/substrafl_doc/api/dependency.html) object is instantiated in order to install the right libraries in\n", - "the Python environment of each organization.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "from substrafl.dependency import Dependency\n", - "\n", - "dependencies = Dependency(pypi_dependencies=[\"numpy==1.26.4\", \"scikit-learn==1.5.0\", \"torch==2.2.1\", \"--extra-index-url https://download.pytorch.org/whl/cpu\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We now have all the necessary objects to launch our experiment. Please see a summary below of all the objects we created so far:\n", - "\n", - "- A [Client](https://docs.substra.org/en/stable/documentation/references/sdk.html#client) to add or retrieve the assets of our experiment, using their keys to identify them.\n", - "- An [Torch Algorithms](https://docs.substra.org/en/stable/substrafl_doc/api/algorithms.html#torch-algorithms) to define the training parameters *(optimizer, train, function, predict function, etc...)*.\n", - "- A [Strategies](https://docs.substra.org/en/stable/substrafl_doc/api/strategies.html#strategies), to specify how to train the model on distributed data.\n", - "- [Train data nodes](https://docs.substra.org/en/stable/substrafl_doc/api/nodes.html#traindatanode) to indicate on which data to train.\n", - "- An [Evaluation Strategy](https://docs.substra.org/en/stable/substrafl_doc/api/evaluation_strategy.html#evaluation-strategy), to define where and at which frequency we evaluate the model.\n", - "- An [Aggregation Node](https://docs.substra.org/en/stable/substrafl_doc/api/nodes.html#aggregationnode), to specify the organization on which the aggregation operation will be computed.\n", - "- The **number of rounds**, a round being defined by a local training step followed by an aggregation operation.\n", 
- "- An **experiment folder** to save a summary of the operation made.\n", - "- The [Dependency](https://docs.substra.org/en/stable/substrafl_doc/api/dependency.html) to define the libraries on which the experiment needs to run." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "from substrafl.experiment import execute_experiment\n", - "\n", - "# A round is defined by a local training step followed by an aggregation operation\n", - "NUM_ROUNDS = 3\n", - "\n", - "compute_plan = execute_experiment(\n", - " client=clients[ALGO_ORG_ID],\n", - " strategy=strategy,\n", - " train_data_nodes=train_data_nodes,\n", - " evaluation_strategy=my_eval_strategy,\n", - " aggregation_node=None,\n", - " num_rounds=NUM_ROUNDS,\n", - " experiment_folder=str(pathlib.Path.cwd() / \"tmp\" / \"experiment_summaries\"),\n", - " dependencies=dependencies,\n", - " clean_models=False,\n", - " name=\"Cyclic MNIST documentation example\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Explore the results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# The results will be available once the compute plan is completed\n", - "client_0.wait_compute_plan(compute_plan.key)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### List results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "performances_df = pd.DataFrame(client.get_performances(compute_plan.key).model_dump())\n", - "print(\"\\nPerformance Table: \\n\")\n", - "print(performances_df[[\"worker\", \"round_idx\", \"identifier\", \"performance\"]])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Plot results" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n", - "fig.suptitle(\"Test dataset results\")\n", - "\n", - "axs[0].set_title(\"Accuracy\")\n", - "axs[1].set_title(\"ROC AUC\")\n", - "\n", - "for ax in axs.flat:\n", - " ax.set(xlabel=\"Rounds\", ylabel=\"Score\")\n", - "\n", - "\n", - "for org_id in DATA_PROVIDER_ORGS_ID:\n", - " org_df = performances_df[performances_df[\"worker\"] == org_id]\n", - " acc_df = org_df[org_df[\"identifier\"] == \"Accuracy\"]\n", - " axs[0].plot(acc_df[\"round_idx\"], acc_df[\"performance\"], label=org_id)\n", - "\n", - " auc_df = org_df[org_df[\"identifier\"] == \"ROC AUC\"]\n", - " axs[1].plot(auc_df[\"round_idx\"], auc_df[\"performance\"], label=org_id)\n", - "\n", - "plt.legend(loc=\"lower right\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Download a model\n", - "\n", - "After the experiment, you might be interested in downloading your trained model.\n", - "To do so, you will need the source code in order to reload your code architecture in memory.\n", - "You have the option to choose the client and the round you are interested in downloading.\n", - "\n", - "If `round_idx` is set to `None`, the last round will be selected by default.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "from substrafl.model_loading import download_algo_state\n", - "\n", - "client_to_download_from = DATA_PROVIDER_ORGS_ID[-1]\n", - "round_idx = None\n", - "\n", - "algo = download_algo_state(\n", - " client=clients[client_to_download_from],\n", - " compute_plan_key=compute_plan.key,\n", - " round_idx=round_idx,\n", - ")\n", - "\n", - "model = algo.model\n", - "\n", - "print(model)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", 
- "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Creating Torch Cyclic strategy on MNIST dataset\n", + "\n", + "This example illustrates an advanced usage of SubstraFL and proposes to implement a new Federated Learning strategy,\n", + "called **Cyclic Strategy**, using the SubstraFL base classes.\n", + "This example runs on the [MNIST Dataset of handwritten digits](http://yann.lecun.com/exdb/mnist/) using PyTorch.\n", + "In this example, we work on 28x28 pixel sized grayscale images. This is a classification problem\n", + "aiming to recognize the number written on each image.\n", + "\n", + "The **Cyclic Strategy** consists in training locally a model on different organizations (or centers) sequentially (one after the other). We\n", + "consider a round of this strategy to be a full cycle of local trainings.\n", + "\n", + "This example shows an implementation of the CyclicTorchAlgo using\n", + "[TorchAlgo](https://docs.substra.org/en/stable/substrafl_doc/api/algorithms.html#torch-algorithms) as base class, and the CyclicStrategy implementation using\n", + "[Strategy](https://docs.substra.org/en/stable/substrafl_doc/api/strategies.html) as base class.\n", + "\n", + "This example does not use a deployed platform of Substra and runs in local mode.\n", + "\n", + "To run this example, you need to download and unzip the assets needed to run it in the same directory as used this example:\n", + "\n", + "- [assets required to run this example](../../../tmp/torch_cyclic_assets.zip)\n", + "\n", + "Please ensure to have all the libraries installed. 
A *requirements.txt* file is included in the zip file, where you can run the command `pip install -r requirements.txt` to install them.\n", + "\n", + "**Substra** and **SubstraFL** should already be installed. If not, follow the instructions described [here](https://docs.substra.org/en/stable/substrafl_doc/substrafl_overview.html#installation).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "This example runs with three organizations. Two organizations provide datasets, while a third\n", + "one provides the algorithm.\n", + "\n", + "In the following code cell, we define the different organizations needed for our FL experiment.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from substra import Client\n", + "\n", + "N_CLIENTS = 3\n", + "\n", + "client_0 = Client(client_name=\"org-1\")\n", + "client_1 = Client(client_name=\"org-2\")\n", + "client_2 = Client(client_name=\"org-3\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Every computation will run in `subprocess` mode, where everything runs locally in Python\n", + "subprocesses.\n", + "Other backend_types are:\n", + "\n", + "- `docker` mode where computations run locally in docker containers\n", + "- `remote` mode where computations run remotely (you need to have a deployed platform for that)\n", + "\n", + "To run in remote mode, use the following syntax:\n", + "\n", + "`client_remote = Client(backend_type=\"remote\", url=\"MY_BACKEND_URL\", username=\"my-username\", password=\"my-password\")`\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Create a dictionary to easily access each client from its human-friendly id\n", + "clients = {\n", + " client_0.organization_info().organization_id: client_0,\n", + " 
client_1.organization_info().organization_id: client_1,\n", + " client_2.organization_info().organization_id: client_2,\n", + "}\n", + "\n", + "# Store organization IDs\n", + "ORGS_ID = list(clients)\n", + "# Algo provider is defined as the first organization.\n", + "ALGO_ORG_ID = ORGS_ID[0]\n", + "# All organizations provide data in this cyclic setup.\n", + "DATA_PROVIDER_ORGS_ID = ORGS_ID" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data and metrics\n", + "\n", + "### Data preparation\n", + "\n", + "This section downloads (if needed) the **MNIST dataset** using the [torchvision library](https://pytorch.org/vision/stable/index.html).\n", + "It extracts the images from the raw files and locally creates a folder for each\n", + "organization.\n", + "\n", + "Each organization will have access to half the training data and half the test data (which\n", + "corresponds to **30,000**\n", + "images for training and **5,000** for testing each).\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import pathlib\n", + "from torch_cyclic_assets.dataset.cyclic_mnist_dataset import setup_mnist\n", + "\n", + "\n", + "# Create the temporary directory for generated data\n", + "(pathlib.Path.cwd() / \"tmp\").mkdir(exist_ok=True)\n", + "data_path = pathlib.Path.cwd() / \"tmp\" / \"data_mnist\"\n", + "\n", + "setup_mnist(data_path, len(DATA_PROVIDER_ORGS_ID))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dataset registration\n", + "\n", + "A [Dataset](https://docs.substra.org/en/stable/documentation/concepts.html#dataset) is composed of an **opener**, which is a Python script that can load\n", + "the data from the files in memory and a description markdown file.\n", + "The [Dataset](https://docs.substra.org/en/stable/documentation/concepts.html#dataset) object itself does not contain the data. 
The proper asset that contains the\n", + "data is the **datasample asset**.\n", + "\n", + "A **datasample** contains a local path to the data. A datasample can be linked to a dataset in order to add data to a\n", + "dataset.\n", + "\n", + "Data privacy is a key concept for Federated Learning experiments. That is why we set\n", + "[Permissions](https://docs.substra.org/en/stable/documentation/concepts.html#permissions) for [Assets](https://docs.substra.org/en/stable/documentation/concepts.html#assets) to determine how each organization\n", + "can access a specific asset.\n", + "You can read more about these concepts in the [User Guide](https://docs.substra.org/en/stable/documentation/concepts.htm).\n", + "\n", + "Note that metadata such as the assets' creation date and the asset owner are visible to all the organizations of a\n", + "network.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from substra.sdk.schemas import DatasetSpec\n", + "from substra.sdk.schemas import Permissions\n", + "from substra.sdk.schemas import DataSampleSpec\n", + "\n", + "assets_directory = pathlib.Path.cwd() / \"torch_cyclic_assets\"\n", + "dataset_keys = {}\n", + "train_datasample_keys = {}\n", + "test_datasample_keys = {}\n", + "\n", + "for i, org_id in enumerate(DATA_PROVIDER_ORGS_ID):\n", + " client = clients[org_id]\n", + "\n", + " permissions_dataset = Permissions(public=False, authorized_ids=[ALGO_ORG_ID])\n", + "\n", + " # DatasetSpec is the specification of a dataset. 
It makes sure every field\n", + " # is well-defined, and that our dataset is ready to be registered.\n", + " # The real dataset object is created in the add_dataset method.\n", + "\n", + " dataset = DatasetSpec(\n", + " name=\"MNIST\",\n", + " data_opener=assets_directory / \"dataset\" / \"cyclic_mnist_opener.py\",\n", + " description=assets_directory / \"dataset\" / \"description.md\",\n", + " permissions=permissions_dataset,\n", + " logs_permission=permissions_dataset,\n", + " )\n", + " dataset_keys[org_id] = client.add_dataset(dataset)\n", + " assert dataset_keys[org_id], \"Missing dataset key\"\n", + "\n", + " # Add the training data on each organization.\n", + " data_sample = DataSampleSpec(\n", + " data_manager_keys=[dataset_keys[org_id]],\n", + " path=data_path / f\"org_{i+1}\" / \"train\",\n", + " )\n", + " train_datasample_keys[org_id] = client.add_data_sample(data_sample)\n", + "\n", + " # Add the testing data on each organization.\n", + " data_sample = DataSampleSpec(\n", + " data_manager_keys=[dataset_keys[org_id]],\n", + " path=data_path / f\"org_{i+1}\" / \"test\",\n", + " )\n", + " test_datasample_keys[org_id] = client.add_data_sample(data_sample)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Metrics definition\n", + "\n", + "A metric is a function used to evaluate the performance of your model.\n", + "\n", + "To add a metric, you need to define a function that computes and returns a performance\n", + "from the data (as returned by the opener) and the predictions of the model.\n", + "\n", + "When using a Torch SubstraFL algorithm, the predictions are returned by the `predict` function.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from sklearn.metrics import accuracy_score\n", + "from sklearn.metrics import roc_auc_score\n", + "import numpy as np\n", + "\n", + "\n", + "def accuracy(data_from_opener, 
predictions):\n", + " y_true = data_from_opener[\"labels\"]\n", + "\n", + " return accuracy_score(y_true, np.argmax(predictions, axis=1))\n", + "\n", + "\n", + "def roc_auc(data_from_opener, predictions):\n", + " y_true = data_from_opener[\"labels\"]\n", + "\n", + " n_class = np.max(y_true) + 1\n", + " y_true_one_hot = np.eye(n_class)[y_true]\n", + "\n", + " return roc_auc_score(y_true_one_hot, predictions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Machine learning components definition\n", + "\n", + "This section uses the PyTorch based SubstraFL API to simplify the definition of machine learning components.\n", + "However, SubstraFL is compatible with any machine learning framework.\n", + "\n", + "\n", + "In this section, you will:\n", + "\n", + "- Register a model and its dependencies\n", + "- Create a federated learning strategy\n", + "- Specify the training and aggregation nodes\n", + "- Specify the test nodes\n", + "- Actually run the computations\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model definition\n", + "\n", + "We choose to use a classic torch CNN as the model to train. 
The model architecture is defined by the user\n", + "independently of SubstraFL.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "\n", + "seed = 42\n", + "torch.manual_seed(seed)\n", + "\n", + "\n", + "class CNN(nn.Module):\n", + " def __init__(self):\n", + " super(CNN, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 32, kernel_size=5)\n", + " self.conv2 = nn.Conv2d(32, 32, kernel_size=5)\n", + " self.conv3 = nn.Conv2d(32, 64, kernel_size=5)\n", + " self.fc1 = nn.Linear(3 * 3 * 64, 256)\n", + " self.fc2 = nn.Linear(256, 10)\n", + "\n", + " def forward(self, x, eval=False):\n", + " x = F.relu(self.conv1(x))\n", + " x = F.relu(F.max_pool2d(self.conv2(x), 2))\n", + " x = F.dropout(x, p=0.5, training=not eval)\n", + " x = F.relu(F.max_pool2d(self.conv3(x), 2))\n", + " x = F.dropout(x, p=0.5, training=not eval)\n", + " x = x.view(-1, 3 * 3 * 64)\n", + " x = F.relu(self.fc1(x))\n", + " x = F.dropout(x, p=0.5, training=not eval)\n", + " x = self.fc2(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + "\n", + "model = CNN()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = torch.nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Specifying on how much data to train\n", + "\n", + "To specify on how much data to train at each round, we use the `index_generator` object.\n", + "We specify the batch size and the number of batches (named `num_updates`) to consider for each round.\n", + "See [Index Generator](https://docs.substra.org/en/stable/substrafl_doc/substrafl_overview.html#index-generator) for more details.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from substrafl.index_generator import 
NpIndexGenerator\n", + "\n", + "# Number of model updates between each FL strategy aggregation.\n", + "NUM_UPDATES = 100\n", + "\n", + "# Number of samples per update.\n", + "BATCH_SIZE = 32\n", + "\n", + "index_generator = NpIndexGenerator(\n", + " batch_size=BATCH_SIZE,\n", + " num_updates=NUM_UPDATES,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Torch Dataset definition\n", + "\n", + "This torch Dataset is used to preprocess the data using the `__getitem__` function.\n", + "\n", + "This torch Dataset needs to have a specific `__init__` signature, that must contain (self, data_from_opener, is_inference).\n", + "\n", + "The `__getitem__` function is expected to return (inputs, outputs) if `is_inference` is `False`, else only the inputs.\n", + "This behavior can be changed by re-writing the `_local_train` or `predict` methods.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class TorchDataset(torch.utils.data.Dataset):\n", + " def __init__(self, data_from_opener, is_inference: bool):\n", + " self.x = data_from_opener[\"images\"]\n", + " self.y = data_from_opener[\"labels\"]\n", + " self.is_inference = is_inference\n", + "\n", + " def __getitem__(self, idx):\n", + " if self.is_inference:\n", + " x = torch.FloatTensor(self.x[idx][None, ...]) / 255\n", + " return x\n", + "\n", + " else:\n", + " x = torch.FloatTensor(self.x[idx][None, ...]) / 255\n", + "\n", + " y = torch.tensor(self.y[idx]).type(torch.int64)\n", + " y = F.one_hot(y, 10)\n", + " y = y.type(torch.float32)\n", + "\n", + " return x, y\n", + "\n", + " def __len__(self):\n", + " return len(self.x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cyclic Strategy implementation\n", + "\n", + "A FL strategy specifies how to train a model on distributed data.\n", + "\n", + "The **Cyclic Strategy** passes the model from an organization to 
the next one, until all\n", + "the data available in Substra has been sequentially presented to the model.\n", + "\n", + "This is not the most efficient strategy. The model will overfit the last dataset it sees,\n", + "and the order of training will impact the performances of the model. But we will use this implementation\n", + "as an example to explain and show how to implement your own strategies using SubstraFL.\n", + "\n", + "To instantiate this new strategy, we need to overwrite three methods:\n", + "\n", + "- `initialization_round`, to indicate what tasks to execute at round 0, in order to set up the variables\n", + " and be able to compute the performances of the model before any training.\n", + "- `perform_round`, to indicate what tasks and in which order we need to compute to execute a round of the strategy.\n", + "- `perform_evaluation`, to indicate how to compute the predictions and performances.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from typing import Any\n", + "from typing import List\n", + "from typing import Optional\n", + "from typing import Dict\n", + "from typing import Callable\n", + "\n", + "from substrafl import strategies\n", + "from substrafl.algorithms.algo import Algo\n", + "from substrafl.nodes.aggregation_node import AggregationNode\n", + "from substrafl.nodes.test_data_node import TestDataNode\n", + "from substrafl.nodes.train_data_node import TrainDataNode\n", + "\n", + "\n", + "class CyclicStrategy(strategies.Strategy):\n", + " \"\"\"The base class Strategy proposes a default compute plan structure\n", + " in its ``build_compute_plan`` method implementation, dedicated to Federated Learning compute plan.\n", + " This method calls ``initialization_round`` at round 0, and then repeats ``perform_round`` for ``num_rounds``.\n", + "\n", + " The default ``build_compute_plan`` implementation also takes into account the given
evaluation\n", + " strategy to trigger the tests tasks when needed.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " algo: Algo,\n", + " metric_functions: Optional[Dict[str, Callable]] = None,\n", + " *args,\n", + " **kwargs,\n", + " ):\n", + " \"\"\"\n", + " It is possible to add any arguments to a Strategy. It is important to pass these arguments as\n", + " args or kwargs to the parent class, using the super().__init__(...) method.\n", + " Indeed, SubstraFL does not use the instance of the object. It re-instantiates them at each new task\n", + " using the args and kwargs passed to the parent class, and uses the save and load local state method to retrieve\n", + " its state.\n", + "\n", + " Args:\n", + " algo (Algo): A Strategy takes an Algo as argument, in order to deal with framework\n", + " specific function in a dedicated object.\n", + " metric_functions (Optional[Dict[str, Callable]]):\n", + " list of Functions that implement the different metrics. If a Dict is given, the keys will be used to\n", + " register the result of the associated function. If a Function or a List is given, function.__name__\n", + " will be used to store the result.\n", + " \"\"\"\n", + " super().__init__(algo=algo, metric_functions=metric_functions, *args, **kwargs)\n", + "\n", + " self._cyclic_local_state = None\n", + " self._cyclic_shared_state = None\n", + "\n", + " @property\n", + " def name(self) -> str:\n", + " \"\"\"The name of the strategy. 
Useful to indicate which Algo\n", + " are compatible or aren't with this strategy.\n", + "\n", + " Returns:\n", + " str: Name of the strategy\n", + " \"\"\"\n", + " return \"Cyclic Strategy\"\n", + "\n", + " def initialization_round(\n", + " self,\n", + " *,\n", + " train_data_nodes: List[TrainDataNode],\n", + " clean_models: bool,\n", + " round_idx: Optional[int] = 0,\n", + " additional_orgs_permissions: Optional[set] = None,\n", + " ):\n", + " \"\"\"The ``initialization_round`` function is called at round 0 on the\n", + " ``build_compute_plan`` method. In our strategy, we want to initialize\n", + " ``_cyclic_local_state`` in order to be able to test the model before\n", + " any training.\n", + "\n", + " We only initialize the model on the first train data node.\n", + "\n", + " Args:\n", + " train_data_nodes (List[TrainDataNode]): Train data nodes representing the different\n", + " organizations containing data we want to train on.\n", + " clean_models (bool): Boolean to indicate if we want to keep intermediate shared states.\n", + " Only taken into account in ``remote`` mode.\n", + " round_idx (Optional[int], optional): Current round index. The initialization round is zero by default,\n", + " but you are free to change it in the ``build_compute_plan`` method. Defaults to 0.\n", + " additional_orgs_permissions (Optional[set], optional): additional organization ids that could\n", + " have access to the outputs the task. 
In our case, this corresponds to the organization\n", + " containing test data nodes, in order to provide access to the model and to allow to\n", + " use it on the test data.\n", + " \"\"\"\n", + " first_train_data_node = train_data_nodes[0]\n", + "\n", + " # The algo.initialize method is an empty method useful to load all Python objects to the platform.\n", + " self._cyclic_local_state = first_train_data_node.init_states(\n", + " operation=self.algo.initialize(\n", + " _algo_name=f\"Initializing with {self.algo.__class__.__name__}\",\n", + " ),\n", + " round_idx=round_idx,\n", + " authorized_ids=set([first_train_data_node.organization_id]) | additional_orgs_permissions,\n", + " clean_models=clean_models,\n", + " )\n", + "\n", + " def perform_round(\n", + " self,\n", + " *,\n", + " train_data_nodes: List[TrainDataNode],\n", + " aggregation_node: Optional[AggregationNode],\n", + " round_idx: int,\n", + " clean_models: bool,\n", + " additional_orgs_permissions: Optional[set] = None,\n", + " ):\n", + " \"\"\"This method is called at each round to perform a series of tasks. For the cyclic\n", + " strategy we want to design, a round is a full cycle over the different train data\n", + " nodes.\n", + " We link the output of a computed task directly to the next one.\n", + "\n", + " Args:\n", + " train_data_nodes (List[TrainDataNode]): Train data nodes representing the different\n", + " organizations containing data we want to train on.\n", + " aggregation_node (Optional[AggregationNode]): In the case of the Cyclic Strategy, there are no\n", + " aggregation tasks so no need for AggregationNode.\n", + " clean_models (bool): Boolean to indicate if we want to keep intermediate shared states.\n", + " Only taken into account in ``remote`` mode.\n", + " round_idx (int): Current round index.\n", + " additional_orgs_permissions (Optional[set], optional): additional organization ids that could\n", + " have access to the outputs of the task.
In our case, this will correspond to the organization\n", + " containing test data nodes, in order to provide access to the model and to allow to\n", + " use it on the test data.\n", + " \"\"\"\n", + " for i, node in enumerate(train_data_nodes):\n", + " # We get the next train_data_node in order to add the organization of the node\n", + " # to the authorized_ids\n", + " next_train_data_node = train_data_nodes[(i + 1) % len(train_data_nodes)]\n", + "\n", + " self._cyclic_local_state, self._cyclic_shared_state = node.update_states(\n", + " operation=self.algo.train(\n", + " node.data_sample_keys,\n", + " shared_state=self._cyclic_shared_state,\n", + " _algo_name=f\"Training with {self.algo.__class__.__name__}\",\n", + " ),\n", + " local_state=self._cyclic_local_state,\n", + " round_idx=round_idx,\n", + " authorized_ids=set([next_train_data_node.organization_id]) | additional_orgs_permissions,\n", + " aggregation_id=None,\n", + " clean_models=clean_models,\n", + " )\n", + "\n", + " def perform_evaluation(\n", + " self,\n", + " test_data_nodes: List[TestDataNode],\n", + " train_data_nodes: List[TrainDataNode],\n", + " round_idx: int,\n", + " ):\n", + " \"\"\"This method is called regarding the given evaluation strategy. 
If the round is included\n", + " in the evaluation strategy, the ``perform_evaluation`` method will be called on the relevant nodes.\n", + "\n", + " We are using the last computed ``_cyclic_local_state`` to feed the test task, which means that we will\n", + " always test the model after its training on the last train data node of the list.\n", + "\n", + " Args:\n", + " test_data_nodes (List[TestDataNode]): List of all the registered test data nodes containing data\n", + " we want to test on.\n", + " train_data_nodes (List[TrainDataNode]): List of all the registered train data nodes.\n", + " round_idx (int): Current round index.\n", + " \"\"\"\n", + " for test_node in test_data_nodes:\n", + " test_node.update_states(\n", + " traintask_id=self._cyclic_local_state.key,\n", + " operation=self.evaluate(\n", + " data_samples=test_node.data_sample_keys,\n", + " _algo_name=f\"Evaluating with {self.__class__.__name__}\",\n", + " ),\n", + " round_idx=round_idx,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Torch Cyclic Algo implementation\n", + "\n", + "A SubstraFL Algo gathers all the defined elements that run locally in each organization.\n", + "This is the only SubstraFL object that is framework-specific (here PyTorch-specific).\n", + "\n", + "In the case of our **Cyclic Strategy**, we need to use the TorchAlgo base class, and\n", + "override the `strategies` property and the `train` method to ensure that we output\n", + "the shared state we need for our Federated Learning compute plan.\n", + "\n", + "For the **Cyclic Strategy**, the **shared state** will be directly the **model parameters**. 
We will\n", + "retrieve the model from the shared state we receive and send the new parameters updated after\n", + "the local training.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from substrafl.algorithms.pytorch.torch_base_algo import TorchAlgo\n", + "from substrafl.remote import remote_data\n", + "from substrafl.algorithms.pytorch import weight_manager\n", + "\n", + "\n", + "class TorchCyclicAlgo(TorchAlgo):\n", + " \"\"\"We create here the base class to be inherited for SubstraFL algorithms.\n", + " An Algo is a SubstraFL object that contains all framework specific functions.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " model: torch.nn.Module,\n", + " criterion: torch.nn.modules.loss._Loss,\n", + " optimizer: torch.optim.Optimizer,\n", + " index_generator: NpIndexGenerator,\n", + " dataset: torch.utils.data.Dataset,\n", + " seed: Optional[int] = None,\n", + " disable_gpu: bool = False,\n", + " *args,\n", + " **kwargs,\n", + " ):\n", + " \"\"\"It is possible to add any arguments to an Algo. It is important to pass these arguments as\n", + " args or kwargs to the parent class, using the super().__init__(...) method.\n", + " Indeed, SubstraFL does not use the instance of the object. 
It re-instantiates the object at each new task\n", + " using the args and kwargs passed to the parent class, and the save and load local state methods to retrieve the\n", + " right state.\n", + "\n", + " Args:\n", + " model (torch.nn.modules.module.Module): A torch model.\n", + " criterion (torch.nn.modules.loss._Loss): A torch criterion (loss).\n", + " optimizer (torch.optim.Optimizer): A torch optimizer linked to the model.\n", + " index_generator (BaseIndexGenerator): a stateful index generator.\n", + " dataset (torch.utils.data.Dataset): an instantiable dataset class whose ``__init__`` receives the data\n", + " returned by the opener and an ``is_inference`` flag.\n", + " seed (typing.Optional[int]): Seed set at the algo initialization on each organization. Defaults to None.\n", + " disable_gpu (bool): Force disabling GPU usage. If False, the GPU will be used if available, otherwise the CPU will be used. Defaults to False.\n", + " \"\"\"\n", + " super().__init__(\n", + " model=model,\n", + " criterion=criterion,\n", + " optimizer=optimizer,\n", + " index_generator=index_generator,\n", + " dataset=dataset,\n", + " scheduler=None,\n", + " seed=seed,\n", + " disable_gpu=disable_gpu,\n", + " *args,\n", + " **kwargs,\n", + " )\n", + "\n", + " @property\n", + " def strategies(self) -> List[str]:\n", + " \"\"\"List of compatible strategies.\n", + "\n", + " Returns:\n", + " List[str]: list of compatible strategy names.\n", + " \"\"\"\n", + " return [\"Cyclic Strategy\"]\n", + "\n", + " @remote_data\n", + " def train(\n", + " self,\n", + " data_from_opener: Any,\n", + " shared_state: Optional[dict] = None,\n", + " ) -> dict:\n", + " \"\"\"This method decorated with ``@remote_data`` is a method that is executed inside\n", + " the train tasks of our strategy.\n", + " The decorator is used to retrieve the entire Algo object inside the task, to be able to access all values\n", + " useful for the training (such as the model, the optimizer, etc.).\n", + " The objective is to perform the local training on the given data 
samples, and send the right shared state\n", + " to the next task.\n", + "\n", + " Args:\n", + " data_from_opener (Any): the output of the ``get_data`` method of an opener. This opener\n", + " accesses the data of a train data node and transforms it to feed methods decorated with\n", + " ``@remote_data``.\n", + " shared_state (Optional[dict], optional): a shared state is a dictionary containing the values\n", + " carried over from the previous trainings of the compute plan, used to initialize the model. In our case,\n", + " the shared state is the model parameters obtained after the local training on the previous organization.\n", + " The shared state is equal to None if it is the first training of the compute plan.\n", + "\n", + " Returns:\n", + " dict: returns a dict corresponding to the shared state that will be used by the next train function on\n", + " a different organization.\n", + " \"\"\"\n", + " # Create torch dataset\n", + " train_dataset = self._dataset(data_from_opener, is_inference=False)\n", + "\n", + " if self._index_generator.n_samples is None:\n", + " # We need to set the index generator's number of samples the first time we have access to\n", + " # the information.\n", + " self._index_generator.n_samples = len(train_dataset)\n", + "\n", + " # If the shared state is None, it means that this is the first training of the compute plan,\n", + " # and that we don't have a shared state to take into account yet.\n", + " if shared_state is not None:\n", + " assert self._index_generator.n_samples is not None\n", + " # The shared state holds the model parameters trained on the previous organization. 
We set\n", + " # the model to these updated values.\n", + " model_parameters = [torch.from_numpy(x).to(self._device) for x in shared_state[\"model_parameters\"]]\n", + " weight_manager.set_parameters(\n", + " model=self._model,\n", + " parameters=model_parameters,\n", + " with_batch_norm_parameters=False,\n", + " )\n", + "\n", + " # We set the counter of updates to zero.\n", + " self._index_generator.reset_counter()\n", + "\n", + " # Train mode for torch model.\n", + " self._model.train()\n", + "\n", + " # Train the model.\n", + " self._local_train(train_dataset)\n", + "\n", + " # We verify that we trained the model on the right number of updates.\n", + " self._index_generator.check_num_updates()\n", + "\n", + " # Eval mode for torch model.\n", + " self._model.eval()\n", + "\n", + " # We get the new model parameter values in order to send them in the shared state.\n", + " model_parameters = weight_manager.get_parameters(model=self._model, with_batch_norm_parameters=False)\n", + " new_shared_state = {\"model_parameters\": [p.cpu().detach().numpy() for p in model_parameters]}\n", + "\n", + " return new_shared_state" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To use your algo, wrap its instantiation in a class whose `__init__` takes no arguments. This is only required when you\n", + "inherit from the TorchAlgo base class.\n", + "\n", + "The `TorchDataset` is passed **as a class** to the [TorchAlgo](https://docs.substra.org/en/stable/substrafl_doc/api/algorithms.html#torch-algorithms).\n", + "Indeed, this `TorchDataset` will be instantiated directly on the data provider organization.\n", + "\n", + "> **⚠ WARNING** \n", + "> It is possible to add any arguments to an Algo or a Strategy. It is important to pass these arguments as\n", + "> args or kwargs to the parent class, using the `super().__init__(...)` method.\n", + ">\n", + "> Indeed, SubstraFL does not use the instance of the object. 
It **re-instantiates** the object at each new task\n", + "> using the args and kwargs passed to the parent class, and the save and load local state methods to retrieve the\n", + "> right state.\n", + "\n", + "To summarize, the `Algo` is the place to put all framework-specific code we want to apply in tasks. These are usually\n", + "the tasks that need data to run and that are decorated with `@remote_data`.\n", + "\n", + "The `Strategy` contains the framework-agnostic code, such as the `build_compute_plan` method, which creates the\n", + "graph of tasks, and the **initialization round**, **perform round** and **perform evaluation** methods, which link tasks to\n", + "each other and link the functions to the nodes.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class MyAlgo(TorchCyclicAlgo):\n", + " def __init__(self):\n", + " super().__init__(\n", + " model=model,\n", + " criterion=criterion,\n", + " optimizer=optimizer,\n", + " index_generator=index_generator,\n", + " dataset=TorchDataset,\n", + " seed=seed,\n", + " disable_gpu=True,\n", + " )\n", + "\n", + "\n", + "strategy = CyclicStrategy(algo=MyAlgo(), metric_functions={\"Accuracy\": accuracy, \"ROC AUC\": roc_auc})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Where to train where to aggregate\n", + "\n", + "We specify on which data we want to train our model, using the [TrainDataNode](https://docs.substra.org/en/stable/substrafl_doc/api/nodes.html#traindatanode) object.\n", + "Here we train on the two datasets that we registered earlier.\n", + "\n", + "The [AggregationNode](https://docs.substra.org/en/stable/substrafl_doc/api/nodes.html#aggregationnode) specifies the organization on which the aggregation operation\n", + "will be computed. The Cyclic Strategy has no aggregation task, so no AggregationNode is created here.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + 
"source": [ + "from substrafl.nodes import TrainDataNode\n", + "\n", + "# Create the Train Data Nodes (or training tasks) and save them in a list\n", + "train_data_nodes = [\n", + " TrainDataNode(\n", + " organization_id=org_id,\n", + " data_manager_key=dataset_keys[org_id],\n", + " data_sample_keys=[train_datasample_keys[org_id]],\n", + " )\n", + " for org_id in DATA_PROVIDER_ORGS_ID\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Where and when to test\n", + "\n", + "With the same logic as the train nodes, we create [TestDataNode](https://docs.substra.org/en/stable/substrafl_doc/api/nodes.html#testdatanode) to specify on which\n", + "data we want to test our model.\n", + "\n", + "The [Evaluation Strategy](https://docs.substra.org/en/stable/substrafl_doc/api/evaluation_strategy.html) defines where and at which frequency we\n", + "evaluate the model, using the given metric(s) that you registered in a previous section.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from substrafl.nodes import TestDataNode\n", + "from substrafl.evaluation_strategy import EvaluationStrategy\n", + "\n", + "# Create the Test Data Nodes (or testing tasks) and save them in a list\n", + "test_data_nodes = [\n", + " TestDataNode(\n", + " organization_id=org_id,\n", + " data_manager_key=dataset_keys[org_id],\n", + " data_sample_keys=[test_datasample_keys[org_id]],\n", + " )\n", + " for org_id in DATA_PROVIDER_ORGS_ID\n", + "]\n", + "\n", + "\n", + "# Test at the end of every round\n", + "my_eval_strategy = EvaluationStrategy(test_data_nodes=test_data_nodes, eval_frequency=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running the experiment\n", + "\n", + "As a last step before launching our experiment, we need to specify the third parties dependencies required to run it.\n", + "The 
[Dependency](https://docs.substra.org/en/stable/substrafl_doc/api/dependency.html) object is instantiated in order to install the right libraries in\n", + "the Python environment of each organization.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from substrafl.dependency import Dependency\n", + "\n", + "dependencies = Dependency(pypi_dependencies=[\"numpy==1.26.4\", \"scikit-learn==1.5.0\", \"torch==2.2.1\", \"--extra-index-url https://download.pytorch.org/whl/cpu\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now have all the necessary objects to launch our experiment. Below is a summary of all the objects we created so far:\n", + "\n", + "- A [Client](https://docs.substra.org/en/stable/documentation/references/sdk.html#client) to add or retrieve the assets of our experiment, using their keys to identify them.\n", + "- A [Torch Algorithm](https://docs.substra.org/en/stable/substrafl_doc/api/algorithms.html#torch-algorithms) to define the training parameters *(optimizer, train function, predict function, etc.)*.\n", + "- A [Strategy](https://docs.substra.org/en/stable/substrafl_doc/api/strategies.html#strategies), to specify how to train the model on distributed data.\n", + "- [Train data nodes](https://docs.substra.org/en/stable/substrafl_doc/api/nodes.html#traindatanode) to indicate on which data to train.\n", + "- An [Evaluation Strategy](https://docs.substra.org/en/stable/substrafl_doc/api/evaluation_strategy.html#evaluation-strategy), to define where and at which frequency we evaluate the model.\n", + "- An [Aggregation Node](https://docs.substra.org/en/stable/substrafl_doc/api/nodes.html#aggregationnode), to specify the organization on which the aggregation operation will be computed (`None` here, since the Cyclic Strategy has no aggregation task).\n", + "- The **number of rounds**, a round being, for the Cyclic Strategy, a full cycle of local trainings over the train data nodes.\n", 
+ "- An **experiment folder** to save a summary of the operations performed.\n", + "- The [Dependency](https://docs.substra.org/en/stable/substrafl_doc/api/dependency.html) to define the libraries on which the experiment needs to run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from substrafl.experiment import execute_experiment\n", + "\n", + "# For the Cyclic Strategy, a round is a full cycle of local trainings over the train data nodes\n", + "NUM_ROUNDS = 3\n", + "\n", + "compute_plan = execute_experiment(\n", + " client=clients[ALGO_ORG_ID],\n", + " strategy=strategy,\n", + " train_data_nodes=train_data_nodes,\n", + " evaluation_strategy=my_eval_strategy,\n", + " aggregation_node=None,\n", + " num_rounds=NUM_ROUNDS,\n", + " experiment_folder=str(pathlib.Path.cwd() / \"tmp\" / \"experiment_summaries\"),\n", + " dependencies=dependencies,\n", + " clean_models=False,\n", + " name=\"Cyclic MNIST documentation example\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Explore the results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# The results will be available once the compute plan is completed\n", + "client_0.wait_compute_plan(compute_plan.key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### List results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "performances_df = pd.DataFrame(client.get_performances(compute_plan.key).model_dump())\n", + "print(\"\\nPerformance Table: \\n\")\n", + "print(performances_df[[\"worker\", \"round_idx\", \"identifier\", \"performance\"]])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot results" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n", + "fig.suptitle(\"Test dataset results\")\n", + "\n", + "axs[0].set_title(\"Accuracy\")\n", + "axs[1].set_title(\"ROC AUC\")\n", + "\n", + "for ax in axs.flat:\n", + " ax.set(xlabel=\"Rounds\", ylabel=\"Score\")\n", + "\n", + "\n", + "for org_id in DATA_PROVIDER_ORGS_ID:\n", + " org_df = performances_df[performances_df[\"worker\"] == org_id]\n", + " acc_df = org_df[org_df[\"identifier\"] == \"Accuracy\"]\n", + " axs[0].plot(acc_df[\"round_idx\"], acc_df[\"performance\"], label=org_id)\n", + "\n", + " auc_df = org_df[org_df[\"identifier\"] == \"ROC AUC\"]\n", + " axs[1].plot(auc_df[\"round_idx\"], auc_df[\"performance\"], label=org_id)\n", + "\n", + "plt.legend(loc=\"lower right\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download a model\n", + "\n", + "After the experiment, you might be interested in downloading your trained model.\n", + "To do so, you will need the source code in order to reload your code architecture in memory.\n", + "You have the option to choose the client and the round you are interested in downloading.\n", + "\n", + "If `round_idx` is set to `None`, the last round will be selected by default.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from substrafl.model_loading import download_algo_state\n", + "\n", + "client_to_download_from = DATA_PROVIDER_ORGS_ID[-1]\n", + "round_idx = None\n", + "\n", + "algo = download_algo_state(\n", + " client=clients[client_to_download_from],\n", + " compute_plan_key=compute_plan.key,\n", + " round_idx=round_idx,\n", + ")\n", + "\n", + "model = algo.model\n", + "\n", + "print(model)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", 
+ "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file