Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

PB-159: remove weights from gRPC messages #298

Merged
merged 12 commits into from
Feb 20, 2020
6 changes: 2 additions & 4 deletions configs/example-config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ fraction_participants = 1.0
[storage]
# (Required) URL to the storage service to use
endpoint = "http://minio-dev:9000"
# (Required) Name of the bucket for storing the global model weights
global_weights_bucket = "xain-fl-aggregated-weights"
# (Required) Name of the bucket for retrieving the local model weights
local_weights_bucket = "xain-fl-participants-weights"
# (Required) Name of the bucket for storing the model weights
bucket = "xain-fl"
# (Required) AWS access key ID to use to authenticate to the storage service
access_key_id = "minio"
# (Required) AWS secret access to use to authenticate to the storage service
Expand Down
6 changes: 2 additions & 4 deletions configs/xain-fl.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ fraction_participants = 1.0
[storage]
# URL to the storage service to use
endpoint = "http://minio-dev:9000"
# Name of the bucket for storing the global model weights
global_weights_bucket = "xain-fl-aggregated-weights"
# Name of the bucket for retrieving the local model weights
local_weights_bucket = "xain-fl-participants-weights"
# Name of the bucket for storing the model weights
bucket = "xain-fl"
# AWS access key ID to use to authenticate to the storage service
access_key_id = "minio"
# AWS secret access to use to authenticate to the storage service
Expand Down
9 changes: 3 additions & 6 deletions docker-compose-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,9 @@ services:
done;
echo Connected!;
mc config host add dev-minio http://minio-dev:9000 $${MINIO_ACCESS_KEY} $${MINIO_SECRET_KEY};
/usr/bin/mc mb -p dev-minio/xain-fl-temporary-weights;
/usr/bin/mc mb -p dev-minio/xain-fl-aggregated-weights;
/usr/bin/mc policy set upload dev-minio/xain-fl-temporary-weights;
/usr/bin/mc policy set download dev-minio/xain-fl-temporary-weights;
/usr/bin/mc policy set upload dev-minio/xain-fl-aggregated-weights;
/usr/bin/mc policy set download dev-minio/xain-fl-aggregated-weights;
/usr/bin/mc mb -p dev-minio/xain-fl;
/usr/bin/mc policy set upload dev-minio/xain-fl;
/usr/bin/mc policy set download dev-minio/xain-fl;
/usr/bin/mc admin trace -v -e dev-minio;
"

Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"numpy==1.15", # BSD
"grpcio==1.23", # Apache License 2.0
"structlog==19.2.0", # Apache License 2.0
"xain-proto==0.5.0", # Apache License 2.0
"xain-proto @ git+https://github.com/xainag/xain-proto.git@PB-159-use-s3-for-transfering-weights#egg=xain_proto-0.6.0&subdirectory=python", # Apache License 2.0
"boto3==1.10.48", # Apache License 2.0
"toml==0.10.0", # MIT
"schema~=0.7", # MIT
Expand All @@ -52,7 +52,8 @@
"pytest==5.3.2", # MIT license
"pytest-cov==2.8.1", # MIT
"pytest-watch==4.2.0", # MIT
"xain-sdk @ git+https://github.com/xainag/xain-sdk.git@development", # Apache License 2.0
"pytest-mock==2.0.0", # MIT
"xain-sdk@ git+https://github.com/xainag/xain-sdk.git@PB-159-use-s3-for-transfering-weights#egg=xain_sdk-0.6.0", # Apache License 2.0
]

docs_require = [
Expand Down
118 changes: 112 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,122 @@
import threading

import grpc
import numpy as np
from numpy import ndarray
import pytest
from xain_proto.fl import coordinator_pb2_grpc
from xain_proto.fl.coordinator_pb2 import EndTrainingRoundRequest

from xain_fl.coordinator.coordinator import Coordinator
from xain_fl.coordinator.coordinator_grpc import CoordinatorGrpc
from xain_fl.coordinator.heartbeat import monitor_heartbeats
from xain_fl.fl.coordinator.aggregate import ModelSumAggregator
from xain_fl.fl.coordinator.controller import IdController
from xain_fl.coordinator.metrics_store import (
AbstractMetricsStore,
NullObjectMetricsStore,
)
from xain_fl.coordinator.store import (
AbstractGlobalWeightsWriter,
AbstractLocalWeightsReader,
)
from xain_fl.fl.coordinator.aggregate import (
Aggregator,
ModelSumAggregator,
WeightedAverageAggregator,
)
from xain_fl.fl.coordinator.controller import Controller, IdController, RandomController

from .port_forwarding import ConnectionManager
from .store import MockS3Coordinator, MockS3Participant, MockS3Resource

# pylint: disable=redefined-outer-name


@pytest.fixture(scope="function")
def s3_mock_stores():
"""
Create a fake S3 store
"""

s3_resource = MockS3Resource()
participant_store = MockS3Participant(s3_resource)
coordinator_store = MockS3Coordinator(s3_resource)
return (coordinator_store, participant_store)


@pytest.fixture(scope="function")
def participant_store(s3_mock_stores):
"""Return an object the participants can use to read the global
weights and write their local weights

"""
return s3_mock_stores[1]


@pytest.fixture(scope="function")
def end_training_request(s3_mock_stores):
"""A fixture that returns a function that can be used to send an
``EndTrainingRequest`` to the coordinator.

"""
participant_store = s3_mock_stores[1]

def wrapped(
coordinator: Coordinator,
participant_id: str,
round: int = 0,
weights: ndarray = ndarray([]),
):
"""Write the local weights for the given round and the given
participant, and send an ``EndTrainingRequest`` on behalf of
that participant.

"""
participant_store.write_weights(participant_id, round, weights)
coordinator.on_message(
EndTrainingRoundRequest(participant_id=participant_id), participant_id
)

return wrapped


@pytest.fixture(scope="function")
def coordinator(s3_mock_stores):
"""
A function that instantiates a new coordinator.
"""
store: MockS3Coordinator = s3_mock_stores[0]
default_global_weights_writer: AbstractGlobalWeightsWriter = store
default_local_weights_reader: AbstractLocalWeightsReader = store

# pylint: disable=too-many-arguments
def wrapped(
global_weights_writer=default_global_weights_writer,
local_weights_reader=default_local_weights_reader,
metrics_store: AbstractMetricsStore = NullObjectMetricsStore(),
num_rounds: int = 1,
minimum_participants_in_round: int = 1,
fraction_of_participants: float = 1.0,
weights: ndarray = np.empty(shape=(0,)),
epochs: int = 1,
epoch_base: int = 0,
aggregator: Aggregator = WeightedAverageAggregator(),
controller: Controller = RandomController(),
):
return Coordinator(
global_weights_writer,
local_weights_reader,
metrics_store=metrics_store,
num_rounds=num_rounds,
minimum_participants_in_round=minimum_participants_in_round,
fraction_of_participants=fraction_of_participants,
weights=weights,
epochs=epochs,
epoch_base=epoch_base,
aggregator=aggregator,
controller=controller,
)

return wrapped


@pytest.fixture()
Expand Down Expand Up @@ -45,14 +151,14 @@ def coordinator_metrics_sample():


@pytest.fixture
def coordinator_service():
def coordinator_service(coordinator):
"""[summary]

.. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425)
"""

server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
coordinator = Coordinator(
coordinator = coordinator(
minimum_participants_in_round=10, fraction_of_participants=1.0
)
coordinator_grpc = CoordinatorGrpc(coordinator)
Expand All @@ -64,7 +170,7 @@ def coordinator_service():


@pytest.fixture
def mock_coordinator_service():
def mock_coordinator_service(coordinator):
"""[summary]

.. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425)
Expand All @@ -73,7 +179,7 @@ def mock_coordinator_service():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
agg = ModelSumAggregator()
ctrl = IdController()
coordinator = Coordinator(
coordinator = coordinator(
num_rounds=2,
minimum_participants_in_round=1,
fraction_of_participants=1.0,
Expand Down
85 changes: 67 additions & 18 deletions tests/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import typing

import numpy as np
from xain_sdk.store import S3GlobalWeightsReader, S3LocalWeightsWriter

from xain_fl.config import StorageConfig
from xain_fl.coordinator.store import S3GlobalWeightsWriter
from xain_fl.coordinator.store import S3GlobalWeightsWriter, S3LocalWeightsReader


class MockS3Resource:
Expand Down Expand Up @@ -62,43 +63,48 @@ def download_fileobj(self, key: str, buf: typing.IO):
self.reads[key] += 1


class MockS3Writer(S3GlobalWeightsWriter):
class MockS3Coordinator(S3GlobalWeightsWriter, S3LocalWeightsReader):
"""A partial mock of the
``xain-fl.coordinator.store.S3GlobalWeightsWriter`` class that
does not perform any IO. Instead, data is stored in memory.
``xain-fl.coordinator.store.S3GlobalWeightsWriter`` and
``xain-fl.coordinator.store.S3LocalWeightsReader`` class that does
not perform any IO. Instead, data is stored in memory.

"""

# We DO NOT want to call the parent class __init__, since it tries
# to initialize a connection to a non-existent external resource
#
# pylint: disable=super-init-not-called
def __init__(self):
def __init__(self, mock_s3_resource):
self.config = StorageConfig(
endpoint="endpoint",
access_key_id="access_key_id",
secret_access_key="secret_access_key",
global_weights_bucket="bucket",
local_weights_bucket="bucket",
bucket="bucket",
)
self.s3 = MockS3Resource()
self.s3 = mock_s3_resource

def assert_read(self, participant_id: str, round: int):
"""Check that the local weights for participant ``participant_id`` at
round ``round`` were read exactly once.

"""
key = f"{round}/{participant_id}"
reads = self.s3.reads[key]
assert reads == 1, f"got {reads} reads for round {key}, expected 1"

def assert_wrote(self, round: int, weights: np.ndarray):
"""Check that the given weights have been written to the store for the
given round.
given round.

Args:
weights (np.ndarray): weights to store
round (int): round to which the weights belong
weights: weights to store
round: round to which the weights belong
"""
writes = self.s3.writes[str(round)]
writes = self.s3.writes[f"{round}/global"]
# Under normal conditions, we should write data exactly once
assert writes == 1, f"got {writes} writes for round {round}, expected 1"
# If the arrays contains `NaN` we cannot compare them, so we
# replace them by zeros to do the comparison
stored_array = np.nan_to_num(self.s3.fake_store[str(round)])
expected_array = np.nan_to_num(weights)
assert np.array_equal(stored_array, expected_array)
np.testing.assert_array_equal(self.s3.fake_store[f"{round}/global"], weights)

def assert_didnt_write(self, round: int):
"""Check that the weights for the given round have NOT been written to the store.
Expand All @@ -107,4 +113,47 @@ def assert_didnt_write(self, round: int):
round (int): round to which the weights belong

"""
assert self.s3.writes[str(round)] == 0
assert self.s3.writes[f"{round}/global"] == 0


class MockS3Participant(S3LocalWeightsWriter, S3GlobalWeightsReader):
"""A partial mock of the ``xain_sdk.store.S3GlobalWeightsReader`` and
``xain_sdk.store.S3LocalWeightsWriter`` class that does not
perform any IO. Instead, data is stored in memory.

"""

def __init__(self, mock_s3_resource):
self.config = StorageConfig(
endpoint="endpoint",
access_key_id="access_key_id",
secret_access_key="secret_access_key",
bucket="bucket",
)
self.s3 = mock_s3_resource

def assert_wrote(self, participant_id: str, round: int, weights: np.ndarray):
"""Check that the given weights have been written to the store for the
given round.

Args:
weights: weights to store
participant_id: ID of the participant
round: round to which the weights belong
"""
key = f"{round}/{participant_id}"
writes = self.s3.writes[key]
assert writes == 1, f"got {writes} writes for {key}, expected 1"
np.testing.assert_array_equal(self.s3.fake_store[key], weights)

def assert_didnt_write(self, participant_id: str, round: int):
"""Check that the weights for the given round have NOT been written to
the store.

Args:
participant_id: ID of the participant
round: round to which the weights belong

"""
key = f"{round}/{participant_id}"
assert self.s3.writes[key] == 0
6 changes: 2 additions & 4 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def storage_sample():
"""
return {
"endpoint": "http://localhost:9000",
"global_weights_bucket": "aggregated_weights",
"local_weights_bucket": "participants_weights",
"bucket": "bucket",
"secret_access_key": "my-secret",
"access_key_id": "my-key-id",
}
Expand Down Expand Up @@ -135,8 +134,7 @@ def test_load_valid_config(config_sample): # pylint: disable=redefined-outer-na
assert config.ai.fraction_participants == 1.0

assert config.storage.endpoint == "http://localhost:9000"
assert config.storage.global_weights_bucket == "aggregated_weights"
assert config.storage.local_weights_bucket == "participants_weights"
assert config.storage.bucket == "bucket"
assert config.storage.secret_access_key == "my-secret"
assert config.storage.access_key_id == "my-key-id"

Expand Down
Loading