Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

Commit

Permalink
Merge pull request #352 from xainag/PB-570-coordinator-service-test-i…
Browse files Browse the repository at this point in the history
…nfra

PB-570: add the first coordinator::core::Service test
  • Loading branch information
little-dude authored Mar 30, 2020
2 parents 11b630a + 5441e47 commit 1b8bcfc
Show file tree
Hide file tree
Showing 14 changed files with 390 additions and 5 deletions.
96 changes: 96 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ opentelemetry = { version = "0.2.0", optional = true }
tracing-opentelemetry = { version = "0.2.0", optional = true }
opentelemetry-jaeger = { version = "0.1.0", optional = true }

[dev-dependencies]
mockall = "0.6.0"

[[bin]]
name = "coordinator"
path = "src/bin/coordinator.rs"
Expand Down
2 changes: 0 additions & 2 deletions rust/src/aggregator/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#![allow(dead_code)]

pub mod api;
pub mod py_aggregator;
pub mod rpc;
Expand Down
2 changes: 1 addition & 1 deletion rust/src/aggregator/py_aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ pub struct PyAggregatorHandle {

impl Aggregator for PyAggregatorHandle {
type Error = ();
type AggregateFut = Pin<Box<dyn Future<Output = Result<Bytes, ()>>>>;
type AggregateFut = Pin<Box<dyn Future<Output = Result<Bytes, ()>> + Send>>;
type AddWeightsFut = Pin<Box<dyn Future<Output = Result<(), ()>> + Send>>;

fn add_weights(&mut self, weights: Bytes) -> Self::AddWeightsFut {
Expand Down
7 changes: 6 additions & 1 deletion rust/src/aggregator/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ mod inner {
}
}

pub use inner::{Rpc, RpcClient as Client};
pub use inner::Rpc;

#[cfg(test)]
pub use crate::tests::lib::rpc::aggregator::Client;
#[cfg(not(test))]
pub use inner::RpcClient as Client;

/// A server that serves a single client. A new `Server` is created
/// for each new client.
Expand Down
2 changes: 2 additions & 0 deletions rust/src/coordinator/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@ mod heartbeat;
mod protocol;
mod service;

#[cfg(test)]
pub(crate) use self::service::ServiceRequests;
pub use self::service::{RequestError, Selector, Service, ServiceHandle};
3 changes: 2 additions & 1 deletion rust/src/coordinator/core/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use tokio::{
},
};

struct AggregationFuture(Pin<Box<dyn Future<Output = Result<(), ()>>>>);
struct AggregationFuture(Pin<Box<dyn Future<Output = Result<(), ()>> + Send>>);

impl Future for AggregationFuture {
type Output = Result<(), ()>;
Expand Down Expand Up @@ -458,6 +458,7 @@ where
}
}

#[derive(Debug)]
pub struct RequestError;

pub struct ServiceRequests(Pin<Box<dyn Stream<Item = Request> + Send>>);
Expand Down
3 changes: 3 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ extern crate serde;
pub mod aggregator;
pub mod common;
pub mod coordinator;

#[cfg(test)]
mod tests;
76 changes: 76 additions & 0 deletions rust/src/tests/coordinator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use crate::{
coordinator::{core::Service, models::HeartBeatResponse, settings::FederatedLearningSettings},
tests::lib::{
coordinator::{MaxSelector, ServiceHandle},
enable_logging,
rpc::aggregator::{Client, MockClient},
sleep_ms,
},
};
use futures::future;
use tokio::task::JoinHandle;

const AGGREGATOR_URL: &str = "http://localhost:8082";

fn start_service(settings: FederatedLearningSettings) -> (Client, ServiceHandle, JoinHandle<()>) {
// Make it easy to debug this test by setting the `TEST_LOGS`
// environment variable
enable_logging();

let rpc_client: Client = MockClient::default().into();

let (service_handle, service_requests) = ServiceHandle::new();

let service = Service::new(
MaxSelector,
settings,
AGGREGATOR_URL.to_string(),
rpc_client.clone(),
service_requests,
);
let join_handle = tokio::spawn(service);
(rpc_client, service_handle, join_handle)
}

/// Test a full cycle with a single round and a single participant.
#[tokio::test]
async fn full_cycle_1_round_1_participant() {
let settings = FederatedLearningSettings {
rounds: 1,
participants_ratio: 1.0,
min_clients: 1,
heartbeat_timeout: 10,
};
let (rpc_client, service_handle, _join_handle) = start_service(settings);

let id = service_handle.rendez_vous_accepted().await;
let round = service_handle.heartbeat_selected(id).await;
assert_eq!(round, 0);

rpc_client
.mock()
.expect_select()
.returning(|_, _| future::ready(Ok(Ok(()))));

let (url, _token) = service_handle.start_training_accepted(id).await;
assert_eq!(&url, AGGREGATOR_URL);

// pretend the client trained and sent its weights to the
// aggregator. The aggregator now sends an end training requests
// to the coordinator RPC server that we fake with the
// service_handle. The service should then trigger the aggregation
// and reject subsequent heartbeats and rendez-vous
rpc_client
.mock()
.expect_aggregate()
.returning(|_| future::ready(Ok(Ok(()))));

service_handle.end_training(id, true).await;
loop {
match service_handle.heartbeat(id).await {
HeartBeatResponse::StandBy => sleep_ms(10).await,
HeartBeatResponse::Finish => break,
_ => panic!("expected StandBy or Finish"),
}
}
}
Loading

0 comments on commit 1b8bcfc

Please sign in to comment.