This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

PB-159: remove weights from gRPC messages (#298)
* PB-159: remove weights from gRPC messages

References:

https://xainag.atlassian.net/browse/PB-159

Needs to be merged along with:

- https://github.com/xainag/xain-proto/pull/25
- https://github.com/xainag/xain-sdk/pull/88
- #298

Summary:

Remove the weights from the gRPC messages. From now on, weights are
exchanged via S3 buckets.

The sequence diagram below illustrates this new behavior.

At the beginning of a round (1), the selected participants send a
`StartTrainingRoundRequest`, and the coordinator responds with a
`StartTrainingRoundResponse` that no longer contains the global
weights.

Instead, the participant fetches these weights from the store (2). S3
buckets are key-value stores, and the key for the global weights is
derived from the round number (`<round_number>/global`, shown as
`round/global` in the diagram).

Then, the participant trains. Once done, it uploads its local weights
to the S3 bucket (3). The key is `<round_number>/<participant_id>`.

Finally (4), the participant sends its `EndTrainingRoundRequest`. Before
answering, the coordinator retrieves the local weights that the
participant has uploaded.

_**Important note**: At the moment, the participants don't know their
ID, because the coordinator does not send it to them. Thus, they
currently generate a random ID when they start, and send it to the
coordinator so that it can retrieve the participant's weights. This is
why the `EndTrainingRoundRequest` currently has a `participant_id`
field._

```
    P                                C                      Store
1.  |   StartTrainingRoundRequest    |                        |
    | -----------------------------> |                        |
    |   StartTrainingRoundResponse   |                        |
    | <----------------------------- |                        |
    |                                |                        |
    |                Get global weights (key="round/global")  |
2.  | ------------------------------------------------------> |
    |                         Global weights                  |
    | <------------------------------------------------------ |
    |                                |                        |
    | [train...]                     |                        |
    |                                |                        |
3.  |       Set local weights (key="round/participant")       |
    | ------------------------------------------------------> |
    |                               Ok                        |
    | <------------------------------------------------------ |
    |                                |                        |
4.  |   EndTrainingRoundRequest      |                        |
    | -----------------------------> | Get local weights (key="round/participant")
    |                                | ---------------------> |
    |                                | Local weights          |
    |  EndTrainingRoundResponse      | <--------------------- |
    | <----------------------------- |                        |
```
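
The participant side of this exchange (steps 2 and 3) can be sketched with an in-memory stand-in for the bucket. This is only an illustration under stated assumptions, not the xain-sdk implementation: `InMemoryStore`, `global_weights_key`, `local_weights_key` and `participant_round` are hypothetical names, weights are plain lists rather than NumPy arrays, and "training" is a placeholder. The key layout follows the diagram above.

```python
from typing import Dict, List

Weights = List[float]


class InMemoryStore:
    """Hypothetical stand-in for the S3 bucket: a plain key-value mapping."""

    def __init__(self) -> None:
        self._objects: Dict[str, Weights] = {}

    def put(self, key: str, weights: Weights) -> None:
        self._objects[key] = list(weights)

    def get(self, key: str) -> Weights:
        return list(self._objects[key])


def global_weights_key(round_number: int) -> str:
    # Key layout taken from the diagram: "round/global".
    return f"{round_number}/global"


def local_weights_key(round_number: int, participant_id: str) -> str:
    # Key layout taken from the commit message: "<round_number>/<participant_id>".
    return f"{round_number}/{participant_id}"


def participant_round(store: InMemoryStore, round_number: int, participant_id: str) -> Weights:
    # Step 2: fetch the global weights for this round from the store.
    global_weights = store.get(global_weights_key(round_number))
    # Placeholder for real training: shift every weight by 1.0.
    local_weights = [w + 1.0 for w in global_weights]
    # Step 3: upload the local weights before sending EndTrainingRoundRequest.
    store.put(local_weights_key(round_number, participant_id), local_weights)
    return local_weights


store = InMemoryStore()
store.put(global_weights_key(0), [0.0, 0.5])
participant_round(store, 0, "participant-a")
print(store.get(local_weights_key(0, "participant-a")))  # [1.0, 1.5]
```

Because the participant ID is part of the key, the coordinator can only find the upload if it learns that ID, which is why the request currently carries a `participant_id` field.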

At the end of the round, the coordinator writes the aggregated global
weights to the S3 bucket, using the next round number as the key (see
the sequence diagram below).

```
P                                C                      Store
|   EndTrainingRoundRequest      |                        |
| -----------------------------> | Get local weights (key="round/participant")
|                                | ---------------------> |
|                                | Local weights          |
|  EndTrainingRoundResponse      | <--------------------- |
| <----------------------------- |                        |
|                                |                        |
|                                | Set global weights (key="round+1/global")
|                                | ---------------------> |
|                                | Ok                     |
|                                | <--------------------- |
```
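
The coordinator's end-of-round step can be sketched in the same spirit. This is a hedged illustration, not the xain-fl code: the store is a plain `dict`, `average` and `end_of_round` are hypothetical helpers, the element-wise mean stands in for whichever aggregator the coordinator is configured with, and the `/global` key suffix is assumed from the first diagram and the test mocks in this commit.

```python
from typing import Dict, List

Weights = List[float]


def average(all_weights: List[Weights]) -> Weights:
    """Element-wise mean; a placeholder for the coordinator's aggregator."""
    count = len(all_weights)
    return [sum(column) / count for column in zip(*all_weights)]


def end_of_round(store: Dict[str, Weights], round_number: int, participant_ids: List[str]) -> Weights:
    # Read what each participant uploaded under "<round>/<participant_id>".
    local_weights = [store[f"{round_number}/{pid}"] for pid in participant_ids]
    new_global = average(local_weights)
    # Write the result under the NEXT round number, so that participants
    # selected for round N+1 find their starting weights at "<N+1>/global".
    store[f"{round_number + 1}/global"] = new_global
    return new_global


store = {"0/a": [1.0, 2.0], "0/b": [3.0, 4.0]}
end_of_round(store, 0, ["a", "b"])
print(store["1/global"])  # [2.0, 3.0]
```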

Implementation notes:

- Initially, we thought we would use different buckets for the
  local and global weights. For now, we use the same bucket for
  both.

- We currently store the global weights under different keys. It turns
  out that this brings unnecessary complexity, so we'll probably
  simplify this in the future.

- For now, the coordinator doesn't send any storage information to the
  participants. Thus, the participants need to be configured with the
  storage information. In the future, the `StartTrainingRoundResponse`
  could contain the endpoint URL, bucket name, etc.
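
As a rough sketch of what "configured with the storage information" could mean on the participant side, the snippet below validates a settings mapping that mirrors the `[storage]` keys used in this commit (`endpoint`, `bucket`, `access_key_id`, `secret_access_key`). `StorageSettings` and `load_storage_settings` are hypothetical names for illustration only; they are not part of xain-fl or xain-sdk.

```python
from dataclasses import dataclass
from typing import Mapping

REQUIRED_KEYS = ("endpoint", "bucket", "access_key_id", "secret_access_key")


@dataclass(frozen=True)
class StorageSettings:
    """Hypothetical participant-side mirror of the [storage] config section."""

    endpoint: str
    bucket: str
    access_key_id: str
    secret_access_key: str


def load_storage_settings(raw: Mapping[str, str]) -> StorageSettings:
    # Fail early with a clear message instead of failing on the first upload.
    missing = [name for name in REQUIRED_KEYS if name not in raw]
    if missing:
        raise ValueError(f"missing storage settings: {', '.join(missing)}")
    return StorageSettings(**{name: raw[name] for name in REQUIRED_KEYS})


settings = load_storage_settings(
    {
        "endpoint": "http://localhost:9000",
        "bucket": "bucket",
        "access_key_id": "my-key-id",
        "secret_access_key": "my-secret",
    }
)
print(settings.bucket)  # bucket
```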
little-dude committed Feb 25, 2020
1 parent 8cd3d71 commit ca0f4ad
Showing 13 changed files with 363 additions and 245 deletions.
6 changes: 2 additions & 4 deletions configs/example-config.toml
@@ -34,10 +34,8 @@ fraction_participants = 1.0
 [storage]
 # (Required) URL to the storage service to use
 endpoint = "http://minio-dev:9000"
-# (Required) Name of the bucket for storing the global model weights
-global_weights_bucket = "xain-fl-aggregated-weights"
-# (Required) Name of the bucket for retrieving the local model weights
-local_weights_bucket = "xain-fl-participants-weights"
+# (Required) Name of the bucket for storing the model weights
+bucket = "xain-fl"
 # (Required) AWS access key ID to use to authenticate to the storage service
 access_key_id = "minio"
 # (Required) AWS secret access to use to authenticate to the storage service
6 changes: 2 additions & 4 deletions configs/xain-fl.toml
@@ -30,10 +30,8 @@ fraction_participants = 1.0
 [storage]
 # URL to the storage service to use
 endpoint = "http://minio-dev:9000"
-# Name of the bucket for storing the global model weights
-global_weights_bucket = "xain-fl-aggregated-weights"
-# Name of the bucket for retrieving the local model weights
-local_weights_bucket = "xain-fl-participants-weights"
+# Name of the bucket for storing the model weights
+bucket = "xain-fl"
 # AWS access key ID to use to authenticate to the storage service
 access_key_id = "minio"
 # AWS secret access to use to authenticate to the storage service
9 changes: 3 additions & 6 deletions docker-compose-dev.yml
@@ -38,12 +38,9 @@ services:
 done;
 echo Connected!;
 mc config host add dev-minio http://minio-dev:9000 $${MINIO_ACCESS_KEY} $${MINIO_SECRET_KEY};
-/usr/bin/mc mb -p dev-minio/xain-fl-temporary-weights;
-/usr/bin/mc mb -p dev-minio/xain-fl-aggregated-weights;
-/usr/bin/mc policy set upload dev-minio/xain-fl-temporary-weights;
-/usr/bin/mc policy set download dev-minio/xain-fl-temporary-weights;
-/usr/bin/mc policy set upload dev-minio/xain-fl-aggregated-weights;
-/usr/bin/mc policy set download dev-minio/xain-fl-aggregated-weights;
+/usr/bin/mc mb -p dev-minio/xain-fl;
+/usr/bin/mc policy set upload dev-minio/xain-fl;
+/usr/bin/mc policy set download dev-minio/xain-fl;
 /usr/bin/mc admin trace -v -e dev-minio;
 "
5 changes: 3 additions & 2 deletions setup.py
@@ -27,7 +27,7 @@
     "numpy==1.15", # BSD
     "grpcio==1.23", # Apache License 2.0
     "structlog==19.2.0", # Apache License 2.0
-    "xain-proto==0.5.0", # Apache License 2.0
+    "xain-proto @ git+https://github.com/xainag/xain-proto.git@PB-159-use-s3-for-transfering-weights#egg=xain_proto-0.6.0&subdirectory=python", # Apache License 2.0
     "boto3==1.10.48", # Apache License 2.0
     "toml==0.10.0", # MIT
     "schema~=0.7", # MIT
@@ -52,7 +52,8 @@
     "pytest==5.3.2", # MIT license
     "pytest-cov==2.8.1", # MIT
     "pytest-watch==4.2.0", # MIT
-    "xain-sdk @ git+https://github.com/xainag/xain-sdk.git@development", # Apache License 2.0
+    "pytest-mock==2.0.0", # MIT
+    "xain-sdk@ git+https://github.com/xainag/xain-sdk.git@PB-159-use-s3-for-transfering-weights#egg=xain_sdk-0.6.0", # Apache License 2.0
 ]

 docs_require = [
118 changes: 112 additions & 6 deletions tests/conftest.py
@@ -5,16 +5,122 @@
 import threading

 import grpc
+import numpy as np
+from numpy import ndarray
 import pytest
 from xain_proto.fl import coordinator_pb2_grpc
+from xain_proto.fl.coordinator_pb2 import EndTrainingRoundRequest

 from xain_fl.coordinator.coordinator import Coordinator
 from xain_fl.coordinator.coordinator_grpc import CoordinatorGrpc
 from xain_fl.coordinator.heartbeat import monitor_heartbeats
-from xain_fl.fl.coordinator.aggregate import ModelSumAggregator
-from xain_fl.fl.coordinator.controller import IdController
+from xain_fl.coordinator.metrics_store import (
+    AbstractMetricsStore,
+    NullObjectMetricsStore,
+)
+from xain_fl.coordinator.store import (
+    AbstractGlobalWeightsWriter,
+    AbstractLocalWeightsReader,
+)
+from xain_fl.fl.coordinator.aggregate import (
+    Aggregator,
+    ModelSumAggregator,
+    WeightedAverageAggregator,
+)
+from xain_fl.fl.coordinator.controller import Controller, IdController, RandomController

 from .port_forwarding import ConnectionManager
+from .store import MockS3Coordinator, MockS3Participant, MockS3Resource
+
+# pylint: disable=redefined-outer-name
+
+
+@pytest.fixture(scope="function")
+def s3_mock_stores():
+    """
+    Create a fake S3 store
+    """
+
+    s3_resource = MockS3Resource()
+    participant_store = MockS3Participant(s3_resource)
+    coordinator_store = MockS3Coordinator(s3_resource)
+    return (coordinator_store, participant_store)
+
+
+@pytest.fixture(scope="function")
+def participant_store(s3_mock_stores):
+    """Return an object the participants can use to read the global
+    weights and write their local weights
+    """
+    return s3_mock_stores[1]
+
+
+@pytest.fixture(scope="function")
+def end_training_request(s3_mock_stores):
+    """A fixture that returns a function that can be used to send an
+    ``EndTrainingRequest`` to the coordinator.
+    """
+    participant_store = s3_mock_stores[1]
+
+    def wrapped(
+        coordinator: Coordinator,
+        participant_id: str,
+        round: int = 0,
+        weights: ndarray = ndarray([]),
+    ):
+        """Write the local weights for the given round and the given
+        participant, and send an ``EndTrainingRequest`` on behalf of
+        that participant.
+        """
+        participant_store.write_weights(participant_id, round, weights)
+        coordinator.on_message(
+            EndTrainingRoundRequest(participant_id=participant_id), participant_id
+        )
+
+    return wrapped
+
+
+@pytest.fixture(scope="function")
+def coordinator(s3_mock_stores):
+    """
+    A function that instantiates a new coordinator.
+    """
+    store: MockS3Coordinator = s3_mock_stores[0]
+    default_global_weights_writer: AbstractGlobalWeightsWriter = store
+    default_local_weights_reader: AbstractLocalWeightsReader = store
+
+    # pylint: disable=too-many-arguments
+    def wrapped(
+        global_weights_writer=default_global_weights_writer,
+        local_weights_reader=default_local_weights_reader,
+        metrics_store: AbstractMetricsStore = NullObjectMetricsStore(),
+        num_rounds: int = 1,
+        minimum_participants_in_round: int = 1,
+        fraction_of_participants: float = 1.0,
+        weights: ndarray = np.empty(shape=(0,)),
+        epochs: int = 1,
+        epoch_base: int = 0,
+        aggregator: Aggregator = WeightedAverageAggregator(),
+        controller: Controller = RandomController(),
+    ):
+        return Coordinator(
+            global_weights_writer,
+            local_weights_reader,
+            metrics_store=metrics_store,
+            num_rounds=num_rounds,
+            minimum_participants_in_round=minimum_participants_in_round,
+            fraction_of_participants=fraction_of_participants,
+            weights=weights,
+            epochs=epochs,
+            epoch_base=epoch_base,
+            aggregator=aggregator,
+            controller=controller,
+        )
+
+    return wrapped


 @pytest.fixture()
@@ -45,14 +151,14 @@ def coordinator_metrics_sample():


 @pytest.fixture
-def coordinator_service():
+def coordinator_service(coordinator):
     """[summary]
     .. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425)
     """

     server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
-    coordinator = Coordinator(
+    coordinator = coordinator(
         minimum_participants_in_round=10, fraction_of_participants=1.0
     )
     coordinator_grpc = CoordinatorGrpc(coordinator)
@@ -64,7 +170,7 @@ def coordinator_service():


 @pytest.fixture
-def mock_coordinator_service():
+def mock_coordinator_service(coordinator):
     """[summary]
     .. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425)
@@ -73,7 +179,7 @@ def mock_coordinator_service():
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
     agg = ModelSumAggregator()
     ctrl = IdController()
-    coordinator = Coordinator(
+    coordinator = coordinator(
         num_rounds=2,
         minimum_participants_in_round=1,
         fraction_of_participants=1.0,
85 changes: 67 additions & 18 deletions tests/store.py
@@ -6,9 +6,10 @@
 import typing

 import numpy as np
+from xain_sdk.store import S3GlobalWeightsReader, S3LocalWeightsWriter

 from xain_fl.config import StorageConfig
-from xain_fl.coordinator.store import S3GlobalWeightsWriter
+from xain_fl.coordinator.store import S3GlobalWeightsWriter, S3LocalWeightsReader


 class MockS3Resource:
@@ -62,43 +63,48 @@ def download_fileobj(self, key: str, buf: typing.IO):
         self.reads[key] += 1


-class MockS3Writer(S3GlobalWeightsWriter):
+class MockS3Coordinator(S3GlobalWeightsWriter, S3LocalWeightsReader):
     """A partial mock of the
-    ``xain-fl.coordinator.store.S3GlobalWeightsWriter`` class that
-    does not perform any IO. Instead, data is stored in memory.
+    ``xain-fl.coordinator.store.S3GlobalWeightsWriter`` and
+    ``xain-fl.coordinator.store.S3LocalWeightsReader`` class that does
+    not perform any IO. Instead, data is stored in memory.
     """

     # We DO NOT want to call the parent class __init__, since it tries
     # to initialize a connection to a non-existent external resource
     #
     # pylint: disable=super-init-not-called
-    def __init__(self):
+    def __init__(self, mock_s3_resource):
         self.config = StorageConfig(
             endpoint="endpoint",
             access_key_id="access_key_id",
             secret_access_key="secret_access_key",
-            global_weights_bucket="bucket",
-            local_weights_bucket="bucket",
+            bucket="bucket",
         )
-        self.s3 = MockS3Resource()
+        self.s3 = mock_s3_resource
+
+    def assert_read(self, participant_id: str, round: int):
+        """Check that the local weights for participant ``participant_id`` at
+        round ``round`` were read exactly once.
+        """
+        key = f"{round}/{participant_id}"
+        reads = self.s3.reads[key]
+        assert reads == 1, f"got {reads} reads for round {key}, expected 1"

     def assert_wrote(self, round: int, weights: np.ndarray):
         """Check that the given weights have been written to the store for the
-        given round.
+        given round.
         Args:
-            weights (np.ndarray): weights to store
-            round (int): round to which the weights belong
+            weights: weights to store
+            round: round to which the weights belong
         """
-        writes = self.s3.writes[str(round)]
+        writes = self.s3.writes[f"{round}/global"]
         # Under normal conditions, we should write data exactly once
         assert writes == 1, f"got {writes} writes for round {round}, expected 1"
-        # If the arrays contains `NaN` we cannot compare them, so we
-        # replace them by zeros to do the comparison
-        stored_array = np.nan_to_num(self.s3.fake_store[str(round)])
-        expected_array = np.nan_to_num(weights)
-        assert np.array_equal(stored_array, expected_array)
+        np.testing.assert_array_equal(self.s3.fake_store[f"{round}/global"], weights)

     def assert_didnt_write(self, round: int):
         """Check that the weights for the given round have NOT been written to the store.
@@ -107,4 +113,47 @@ def assert_didnt_write(self, round: int):
             round (int): round to which the weights belong
         """
-        assert self.s3.writes[str(round)] == 0
+        assert self.s3.writes[f"{round}/global"] == 0
+
+
+class MockS3Participant(S3LocalWeightsWriter, S3GlobalWeightsReader):
+    """A partial mock of the ``xain_sdk.store.S3GlobalWeightsReader`` and
+    ``xain_sdk.store.S3LocalWeightsWriter`` class that does not
+    perform any IO. Instead, data is stored in memory.
+    """
+
+    def __init__(self, mock_s3_resource):
+        self.config = StorageConfig(
+            endpoint="endpoint",
+            access_key_id="access_key_id",
+            secret_access_key="secret_access_key",
+            bucket="bucket",
+        )
+        self.s3 = mock_s3_resource
+
+    def assert_wrote(self, participant_id: str, round: int, weights: np.ndarray):
+        """Check that the given weights have been written to the store for the
+        given round.
+        Args:
+            weights: weights to store
+            participant_id: ID of the participant
+            round: round to which the weights belong
+        """
+        key = f"{round}/{participant_id}"
+        writes = self.s3.writes[key]
+        assert writes == 1, f"got {writes} writes for {key}, expected 1"
+        np.testing.assert_array_equal(self.s3.fake_store[key], weights)
+
+    def assert_didnt_write(self, participant_id: str, round: int):
+        """Check that the weights for the given round have NOT been written to
+        the store.
+        Args:
+            participant_id: ID of the participant
+            round: round to which the weights belong
+        """
+        key = f"{round}/{participant_id}"
+        assert self.s3.writes[key] == 0
6 changes: 2 additions & 4 deletions tests/test_config.py
@@ -43,8 +43,7 @@ def storage_sample():
     """
     return {
         "endpoint": "http://localhost:9000",
-        "global_weights_bucket": "aggregated_weights",
-        "local_weights_bucket": "participants_weights",
+        "bucket": "bucket",
         "secret_access_key": "my-secret",
         "access_key_id": "my-key-id",
     }
@@ -135,8 +134,7 @@ def test_load_valid_config(config_sample): # pylint: disable=redefined-outer-na
     assert config.ai.fraction_participants == 1.0

     assert config.storage.endpoint == "http://localhost:9000"
-    assert config.storage.global_weights_bucket == "aggregated_weights"
-    assert config.storage.local_weights_bucket == "participants_weights"
+    assert config.storage.bucket == "bucket"
     assert config.storage.secret_access_key == "my-secret"
     assert config.storage.access_key_id == "my-key-id"
