From ee7a2751aed36bdec48c7ec94d7e577f877897c5 Mon Sep 17 00:00:00 2001 From: little-dude Date: Tue, 4 Aug 2020 17:33:20 +0200 Subject: [PATCH] refactor tracing use of the `Traced` struct has two downsides: - Because we want to decouple the tracing logic from the application logic, we introduced generics all over the place for our code to work both with T and Traced. - We currently pass the `Traced.span()` down all the services chain, but `Span`s are more of less immutable and declared in macros at compile time, so just passing this span down is useless: we cannot at any field to it. This refactoring is twofold: 1. Remove all the genericity around tracing. This resulted in a lot of simplifications, especially in the state machine, but also in the pet message services since we could get rid of the `_PetMessageHandler` trait. 2. Replace Traced by a Request that is going to be used everywhere. Note that we currently just use `Request` in for the PET message services and the state machine, but it should be straightforward to use anywhere else. Unlike `Traced`, `Request` allows us to enrich the span attached to the request: `Request::map` creates a child span a the request is mapped. --- rust/src/bin/main.rs | 10 +- rust/src/client/request.rs | 2 +- rust/src/lib.rs | 2 +- rust/src/message/message.rs | 17 + rust/src/rest.rs | 23 +- rust/src/services/messages/message_parser.rs | 50 ++- rust/src/services/messages/mod.rs | 351 ++++++------------ .../services/messages/pre_processor/mod.rs | 15 +- rust/src/services/messages/state_machine.rs | 143 +------ rust/src/services/mod.rs | 10 +- .../services/tests/messages/message_parser.rs | 18 +- .../services/tests/messages/pre_processor.rs | 6 +- rust/src/state_machine/mod.rs | 62 ++-- rust/src/state_machine/phases/error.rs | 15 +- rust/src/state_machine/phases/idle.rs | 30 +- rust/src/state_machine/phases/mod.rs | 117 ++---- rust/src/state_machine/phases/shutdown.rs | 11 +- rust/src/state_machine/phases/sum.rs | 62 ++-- rust/src/state_machine/phases/sum2.rs | 57 ++- rust/src/state_machine/phases/unmask.rs | 13 +- rust/src/state_machine/phases/update.rs | 57 ++- rust/src/state_machine/requests.rs | 134 ++++--- rust/src/state_machine/tests/builder.rs | 18 +- rust/src/state_machine/tests/impls.rs | 86 +---- rust/src/state_machine/tests/mod.rs | 12 +- rust/src/utils/mod.rs | 3 +- rust/src/utils/request.rs | 132 +++++++ rust/src/utils/trace.rs | 46 --- 28 files changed, 624 insertions(+), 878 deletions(-) create mode 100644 rust/src/utils/request.rs delete mode 100644 rust/src/utils/trace.rs diff --git a/rust/src/bin/main.rs b/rust/src/bin/main.rs index b8e532704..658ce816d 100644 --- a/rust/src/bin/main.rs +++ b/rust/src/bin/main.rs @@ -1,13 +1,7 @@ use std::{path::PathBuf, process}; use structopt::StructOpt; use tracing_subscriber::*; -use xaynet::{ - rest, - services, - settings::Settings, - state_machine::{requests::Request, StateMachine}, - utils::trace::Traced, -}; +use xaynet::{rest, services, settings::Settings, state_machine::StateMachine}; #[derive(Debug, StructOpt)] #[structopt(name = "Coordinator")] @@ -41,7 +35,7 @@ async fn main() { sodiumoxide::init().unwrap(); let (state_machine, requests_tx, event_subscriber) = - StateMachine::>::new(pet_settings, mask_settings, model_settings).unwrap(); + StateMachine::new(pet_settings, mask_settings, model_settings).unwrap(); let fetcher = services::fetcher(&event_subscriber); let message_handler = services::message_handler(&event_subscriber, requests_tx); diff --git a/rust/src/client/request.rs b/rust/src/client/request.rs index e3bf6de2d..fd27ef587 100644 --- a/rust/src/client/request.rs +++ b/rust/src/client/request.rs @@ -53,7 +53,7 @@ impl Proxy { /// handling the message. /// * Returns `NetworkErr` if a network error occurs while posting the PET /// message. - pub async fn post_message(&self, msg: Vec) -> Result<(), ClientError> { + pub async fn post_message(&mut self, msg: Vec) -> Result<(), ClientError> { match self { InMem(_, hdl) => hdl .handle_message(msg) diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 09b724ccb..2bcc0b513 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -113,7 +113,7 @@ use self::crypto::{ /// An error related to insufficient system entropy for secrets at program startup. pub struct InitError; -#[derive(Debug, PartialEq, Display)] +#[derive(Debug, Display, Error)] /// Errors related to the PET protocol. pub enum PetError { InvalidMessage, diff --git a/rust/src/message/message.rs b/rust/src/message/message.rs index 472c0ced0..4daf02425 100644 --- a/rust/src/message/message.rs +++ b/rust/src/message/message.rs @@ -7,6 +7,7 @@ use std::borrow::Borrow; use anyhow::{anyhow, Context}; +use tracing::Span; use crate::{ certificate::Certificate, @@ -28,6 +29,7 @@ use crate::{ traits::{FromBytes, ToBytes}, DecodeError, }, + utils::Traceable, LocalSeedDict, }; @@ -195,3 +197,18 @@ impl<'a, 'b> MessageOpen<'a, 'b> { Ok(message) } } + +impl Traceable for MessageOwned { + fn make_span(&self) -> Span { + let message_type = match self.payload { + PayloadOwned::Sum(_) => "sum", + PayloadOwned::Update(_) => "update", + PayloadOwned::Sum2(_) => "sum2", + }; + error_span!( + "MessageOwned", + message_type = message_type, + message_length = self.buffer_length() + ) + } +} diff --git a/rust/src/rest.rs b/rust/src/rest.rs index b2dbd74fe..c12886602 100644 --- a/rust/src/rest.rs +++ b/rust/src/rest.rs @@ -24,14 +24,13 @@ pub async fn serve( pet_message_handler: MH, ) where F: Fetcher + Sync + Send + 'static, - MH: PetMessageHandler + Sync + Send + 'static, + MH: PetMessageHandler + Sync + Send + 'static + Clone, { let fetcher = Arc::new(fetcher); - let message_handler = Arc::new(pet_message_handler); let message = warp::path!("message") .and(warp::post()) .and(warp::body::bytes()) - .and(with_message_handler(message_handler.clone())) + .and(with_message_handler(pet_message_handler.clone())) .and_then(handle_message); let sum_dict = warp::path!("sums") @@ -81,15 +80,11 @@ pub async fn serve( /// Handles and responds to a PET message. async fn handle_message( body: Bytes, - handler: Arc, + mut handler: MH, ) -> Result { - let _ = handler - .as_ref() - .handle_message(body.to_vec()) - .await - .map_err(|e| { - warn!("failed to handle message: {:?}", e); - }); + let _ = handler.handle_message(body.to_vec()).await.map_err(|e| { + warn!("failed to handle message: {:?}", e); + }); Ok(warp::reply()) } @@ -227,9 +222,9 @@ async fn handle_params(fetcher: Arc) -> Result( - handler: Arc, -) -> impl Filter,), Error = Infallible> + Clone { +fn with_message_handler( + handler: MH, +) -> impl Filter + Clone { warp::any().map(move || handler.clone()) } diff --git a/rust/src/services/messages/message_parser.rs b/rust/src/services/messages/message_parser.rs index dd52401d5..fc1e09097 100644 --- a/rust/src/services/messages/message_parser.rs +++ b/rust/src/services/messages/message_parser.rs @@ -9,6 +9,7 @@ use rayon::ThreadPool; use thiserror::Error; use tokio::sync::oneshot; use tower::Service; +use tracing::Span; use crate::{ crypto::{encrypt::EncryptKeyPair, ByteObject}, @@ -23,7 +24,7 @@ use crate::{ events::{EventListener, EventSubscriber}, phases::PhaseName, }, - utils::trace::{Traceable, Traced}, + utils::{Request, Traceable}, Signature, }; @@ -58,16 +59,18 @@ impl MessageParserService { } } -/// Request type for the [`MessageParserService`]. -/// -/// It contains the encrypted message. +/// A buffer that represents an encrypted message. #[derive(From, Debug)] -pub struct MessageParserRequest(Vec); - -/// Response type for the [`MessageParserService`]. -/// -/// It contains the parsed message. -pub type MessageParserResponse = Result; +pub struct RawMessage>(T); + +impl Traceable for RawMessage +where + T: AsRef<[u8]>, +{ + fn make_span(&self) -> Span { + error_span!("raw_message", payload_len = self.0.as_ref().len()) + } +} /// Error type for the [`MessageParserService`] #[derive(Debug, Error)] @@ -93,7 +96,16 @@ pub enum MessageParserError { InternalError(String), } -impl Service> for MessageParserService { +/// Response type for the [`MessageParserService`] +pub type MessageParserResponse = Result; + +/// Request type for the [`MessageParserService`] +pub type MessageParserRequest = Request>; + +impl Service> for MessageParserService +where + T: AsRef<[u8]> + Send + 'static, +{ type Response = MessageParserResponse; type Error = std::convert::Infallible; @@ -107,7 +119,7 @@ impl Service> for MessageParserService { Poll::Ready(Ok(())) } - fn call(&mut self, req: Traced) -> Self::Future { + fn call(&mut self, req: MessageParserRequest) -> Self::Future { debug!("retrieving the current keys and current phase"); let keys_ev = self.keys_events.get_latest(); let phase_ev = self.phase_events.get_latest(); @@ -130,9 +142,9 @@ impl Service> for MessageParserService { trace!("spawning pre-processor handler on thread-pool"); self.thread_pool.spawn(move || { - let span = req.span().clone(); - let _enter = span.enter(); - let resp = handler.call(req.into_inner().0); + let span = req.span(); + let _span_guard = span.enter(); + let resp = handler.call(req.into_inner()); let _ = tx.send(resp); }); Either::Right(Box::pin(async move { @@ -156,9 +168,9 @@ struct Handler { impl Handler { /// Process the request. `data` is the encrypted PET message to /// process. - fn call(self, data: Vec) -> Result { + fn call>(self, data: RawMessage) -> MessageParserResponse { info!("decrypting message"); - let raw = self.decrypt(data)?; + let raw = self.decrypt(&data.0.as_ref())?; info!("parsing message header"); let header = self.parse_header(raw.as_slice())?; @@ -177,11 +189,11 @@ impl Handler { } /// Decrypt the given payload with the coordinator secret key - fn decrypt(&self, encrypted_message: Vec) -> Result, MessageParserError> { + fn decrypt(&self, encrypted_message: &[u8]) -> Result, MessageParserError> { Ok(self .keys .secret - .decrypt(&encrypted_message.as_ref(), &self.keys.public) + .decrypt(&encrypted_message, &self.keys.public) .map_err(|_| MessageParserError::Decrypt)?) } diff --git a/rust/src/services/messages/mod.rs b/rust/src/services/messages/mod.rs index 88774012e..0c4fa520e 100644 --- a/rust/src/services/messages/mod.rs +++ b/rust/src/services/messages/mod.rs @@ -28,46 +28,39 @@ pub use self::{ }, }; -use std::{ - pin::Pin, - task::{Context, Poll}, +use crate::{ + services::messages::message_parser::RawMessage, + utils::Traceable, + vendor::tracing_tower, }; -use crate::vendor::tracing_tower; -use futures::{future::poll_fn, Future}; +use futures::future::poll_fn; use thiserror::Error; use tower::{Service, ServiceBuilder}; -use tracing_futures::Instrument; -use uuid::Uuid; - -use crate::{ - message::message::MessageOwned, - utils::trace::{Traceable, Traced}, -}; -/// Associate an ID to the given request, and attach a span to the request. -fn make_traceable_request(req: R) -> Traced { - let id = Uuid::new_v4(); - let span = error_span!("request", id = ?id); - Traced::new(req, span) -} +use crate::{message::message::MessageOwned, utils::Request}; /// Return the [`tracing::Span`] associated to the given request. -fn req_span(req: &Traced) -> tracing::Span { - req.span().clone() +fn req_span(req: &Request) -> tracing::Span { + req.span() } /// Decorate the given service with a tracing middleware. -fn with_tracing(service: S) -> TracingService +fn with_tracing(service: S) -> TracedService where - S: Service>, + S: Service>, + T: Traceable, { ServiceBuilder::new() .layer(tracing_tower::layer(req_span as for<'r> fn(&'r _) -> _)) .service(service) } -type TracingService = tracing_tower::Service, fn(&Traced) -> tracing::Span>; +type TracedService = tracing_tower::Service, fn(&Request) -> tracing::Span>; + +type TracedMessageParser = TracedService>>; +type TracedPreProcessor = TracedService; +type TracedStateMachine = TracedService; /// Error returned by the [`PetMessageHandler`] methods. #[derive(Debug, Error)] @@ -80,150 +73,121 @@ pub enum PetMessageError { #[error("state machine failed to handle message: {0}")] StateMachine(StateMachineError), -} - -#[doc(hidden)] -#[async_trait] -pub trait _PetMessageHandler { - /// Parse an encrypted message - async fn call_parser(&self, enc_message: Traced>) -> MessageParserResponse; - - /// Pre-process a PET message - async fn call_pre_processor(&self, message: Traced) -> PreProcessorResponse; - /// Have a PET message processed by the state machine - async fn call_state_machine(&self, message: Traced) -> StateMachineResponse; + #[error("the service failed to process the request: {0}")] + ServiceError(Box), } /// A single interface for all the PET message processing sub-services /// ([`MessageParserService`], [`PreProcessorService`] and /// [`StateMachineService`]). #[async_trait] -pub trait PetMessageHandler { - /// Handle an incoming encrypted PET message form a participant. - async fn handle_message(&self, enc_message: Vec) -> Result<(), PetMessageError>; -} +pub trait PetMessageHandler: Send { + async fn handle_message( + &mut self, + // FIXME: this should take a `Request<_>` instead that should + // be created by the caller (in the rest layer). + req: Vec, + ) -> Result<(), PetMessageError> { + let req = Request::new(RawMessage::from(req)); + let metadata = req.metadata(); + let message = self.call_parser(req).await?; + + let req = Request::from_parts(metadata.clone(), message); + let message = self.call_pre_processor(req).await?; + + let req = Request::from_parts(metadata, message); + Ok(self.call_state_machine(req).await?) + } -#[async_trait] -impl PetMessageHandler for T -where - T: _PetMessageHandler + Sync, -{ - async fn handle_message(&self, enc_message: Vec) -> Result<(), PetMessageError> { - let req = make_traceable_request(enc_message); - let span = req.span().clone(); - let message = self - .call_parser(req) - .await - .map_err(PetMessageError::Parser)?; + /// Parse an encrypted message + async fn call_parser( + &mut self, + enc_message: MessageParserRequest>, + ) -> Result; - let req = Traced::new(message, span.clone()); - let message = self - .call_pre_processor(req) - .await - .map_err(PetMessageError::PreProcessor)?; + /// Pre-process a PET message + async fn call_pre_processor( + &mut self, + message: PreProcessorRequest, + ) -> Result; - let req = Traced::new(message, span.clone()); - Ok(self - .call_state_machine(req) - .await - .map_err(PetMessageError::StateMachine)?) - } + /// Have a PET message processed by the state machine + async fn call_state_machine( + &mut self, + message: StateMachineRequest, + ) -> Result<(), PetMessageError>; } #[async_trait] -impl _PetMessageHandler for PetMessageService +impl PetMessageHandler for PetMessageService where - Self: Clone - + Send - + Sync - + 'static - + Service, Response = MessageParserResponse> - + Service, Response = PreProcessorResponse> - + Service, Response = StateMachineResponse>, + Self: Send + Sync + 'static, - >>::Future: Send + 'static, - >>::Error: + MP: Service>, Response = MessageParserResponse> + Send + 'static, + >>>::Future: Send + 'static, + >>>::Error: Into>, - >>::Future: Send + 'static, - >>::Error: + PP: Service + Send + 'static, + >::Future: Send + 'static, + >::Error: Into>, - >>::Future: Send + 'static, - >>::Error: + SM: Service + Send + 'static, + >::Future: Send + 'static, + >::Error: Into>, { - async fn call_parser(&self, enc_message: Traced>) -> MessageParserResponse { - let span = enc_message.span().clone(); - let mut svc = self.clone(); - async move { - poll_fn(|cx| >>::poll_ready(&mut svc, cx)) - .await - .map_err(Into::into) - // FIXME: do not unwrap. For now it is fine because we - // actually only use MessageParserService directly, - // which never fails. - .unwrap(); - >>::call( - &mut svc, - enc_message.map(Into::into), - ) - .await - .map_err(Into::into) - // FIXME: do not unwrap. For now it is fine because we - // actually only use MessageParserService directly, - // which never fails. - .unwrap() - } - .instrument(span) + async fn call_parser( + &mut self, + enc_message: MessageParserRequest>, + ) -> Result { + poll_fn(|cx| { + >>>::poll_ready(&mut self.message_parser, cx) + }) + .await + // FIXME: we should actually downcast the error and + // distinguish between the various services errors we can + // have. Currently, this will just turn the error into a + // Box + .map_err(|e| PetMessageError::ServiceError(Into::into(e)))?; + + >>>::call( + &mut self.message_parser, + enc_message.map(Into::into), + ) .await + .map_err(|e| PetMessageError::ServiceError(Into::into(e)))? + .map_err(PetMessageError::Parser) } - async fn call_pre_processor(&self, message: Traced) -> PreProcessorResponse { - let span = message.span().clone(); - let mut svc = self.clone(); - async move { - poll_fn(|cx| >>::poll_ready(&mut svc, cx)) - .await - .map_err(Into::into) - // FIXME: do not unwrap. For now it is fine because we - // actually only use PreProcessorService directly, - // which never fails. - .unwrap(); - >>::call(&mut svc, message.map(Into::into)) - .await - .map_err(Into::into) - // FIXME: do not unwrap. For now it is fine because we - // actually only use PreProcessorService directly, - // which never fails. - .unwrap() - } - .instrument(span) - .await + async fn call_pre_processor( + &mut self, + message: PreProcessorRequest, + ) -> Result { + poll_fn(|cx| >::poll_ready(&mut self.pre_processor, cx)) + .await + .map_err(|e| PetMessageError::ServiceError(Into::into(e)))?; + + >::call(&mut self.pre_processor, message.map(Into::into)) + .await + .map_err(|e| PetMessageError::ServiceError(Into::into(e)))? + .map_err(PetMessageError::PreProcessor) } - async fn call_state_machine(&self, message: Traced) -> StateMachineResponse { - let span = message.span().clone(); - let mut svc = self.clone(); - async move { - poll_fn(|cx| >>::poll_ready(&mut svc, cx)) - .await - .map_err(Into::into) - // FIXME: do not unwrap. For now it is fine because we - // actually only use StateMachineService directly, - // which never fails. - .unwrap(); - >>::call(&mut svc, message.map(Into::into)) - .await - .map_err(Into::into) - // FIXME: do not unwrap. For now it is fine because we - // actually only use StateMachineService directly, - // which never fails. - .unwrap() - } - .instrument(span) - .await + async fn call_state_machine( + &mut self, + message: StateMachineRequest, + ) -> Result<(), PetMessageError> { + poll_fn(|cx| >::poll_ready(&mut self.state_machine, cx)) + .await + .map_err(|e| PetMessageError::ServiceError(Into::into(e)))?; + + >::call(&mut self.state_machine, message.map(Into::into)) + .await + .map_err(|e| PetMessageError::ServiceError(Into::into(e)))? + .map_err(PetMessageError::StateMachine) } } @@ -249,15 +213,11 @@ pub struct PetMessageService { } impl - PetMessageService< - TracingService, - TracingService, - TracingService, - > + PetMessageService, TracedPreProcessor, TracedStateMachine> where - MP: Service, Response = MessageParserResponse>, - PP: Service, Response = PreProcessorResponse>, - SM: Service, Response = StateMachineResponse>, + MP: Service>, Response = MessageParserResponse>, + PP: Service, + SM: Service, { /// Instantiate a new [`PetMessageService`] with the given sub-services pub fn new(message_parser: MP, pre_processor: PP, state_machine: SM) -> Self { @@ -268,96 +228,3 @@ where } } } - -impl Service> for PetMessageService -where - MP: Service, Response = MessageParserResponse> - + Clone - + Send - + 'static, - >>::Future: Send + 'static, - >>::Error: - Into>, -{ - type Response = MessageParserResponse; - type Error = Box; - #[allow(clippy::type_complexity)] - type Future = - Pin> + Send + 'static>>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - >>::poll_ready(&mut self.message_parser, cx) - .map_err(Into::into) - } - - fn call(&mut self, req: Traced) -> Self::Future { - let mut svc = self.message_parser.clone(); - let fut = async move { - info!("calling the message parser service on the request"); - svc.call(req).await.map_err(Into::into) - }; - Box::pin(fut) - } -} - -impl Service> for PetMessageService -where - PP: Service, Response = PreProcessorResponse> - + Clone - + Send - + 'static, - >>::Future: Send + 'static, - >>::Error: - Into>, -{ - type Response = PreProcessorResponse; - type Error = Box; - #[allow(clippy::type_complexity)] - type Future = - Pin> + Send + 'static>>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - >>::poll_ready(&mut self.pre_processor, cx) - .map_err(Into::into) - } - - fn call(&mut self, req: Traced) -> Self::Future { - let mut svc = self.pre_processor.clone(); - let fut = async move { - info!("calling the pre-processor service on the request"); - svc.call(req).await.map_err(Into::into) - }; - Box::pin(fut) - } -} - -impl Service> for PetMessageService -where - SM: Service, Response = StateMachineResponse> - + Clone - + Send - + 'static, - >>::Future: Send + 'static, - >>::Error: - Into>, -{ - type Response = StateMachineResponse; - type Error = Box; - #[allow(clippy::type_complexity)] - type Future = - Pin> + Send + 'static>>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - >>::poll_ready(&mut self.state_machine, cx) - .map_err(Into::into) - } - - fn call(&mut self, req: Traced) -> Self::Future { - let mut svc = self.state_machine.clone(); - let fut = async move { - info!("calling the state machine service on the request"); - svc.call(req).await.map_err(Into::into) - }; - Box::pin(fut) - } -} diff --git a/rust/src/services/messages/pre_processor/mod.rs b/rust/src/services/messages/pre_processor/mod.rs index 2e3990903..b8c090986 100644 --- a/rust/src/services/messages/pre_processor/mod.rs +++ b/rust/src/services/messages/pre_processor/mod.rs @@ -9,7 +9,6 @@ pub use sum2::Sum2PreProcessorService; use std::{pin::Pin, task::Poll}; -use derive_more::From; use futures::{ future::{self, Future}, task::Context, @@ -24,7 +23,7 @@ use crate::{ events::{Event, EventListener, EventSubscriber}, phases::PhaseName, }, - utils::trace::{Traceable, Traced}, + utils::request::Request, }; /// A service for performing sanity checks and preparing incoming @@ -56,15 +55,13 @@ impl PreProcessorService { } } -/// Request type for [`PreProcessorService`]. It contains the PET -/// message to handle. -#[derive(From, Debug)] -pub struct PreProcessorRequest(MessageOwned); +/// Request type for [`PreProcessorService`] +pub type PreProcessorRequest = Request; /// Response type for [`PreProcessorService`] pub type PreProcessorResponse = Result; -impl Service> for PreProcessorService { +impl Service for PreProcessorService { type Response = PreProcessorResponse; type Error = std::convert::Infallible; @@ -82,8 +79,8 @@ impl Service> for PreProcessorService { } } - fn call(&mut self, req: Traced) -> Self::Future { - let MessageOwned { header, payload } = req.into_inner().0; + fn call(&mut self, req: PreProcessorRequest) -> Self::Future { + let MessageOwned { header, payload } = req.into_inner(); match (self.latest_phase_event.event, payload) { (PhaseName::Sum, PayloadOwned::Sum(sum)) => { let req = (header, sum, self.params_listener.get_latest().event); diff --git a/rust/src/services/messages/state_machine.rs b/rust/src/services/messages/state_machine.rs index e747ca619..5133a4a90 100644 --- a/rust/src/services/messages/state_machine.rs +++ b/rust/src/services/messages/state_machine.rs @@ -1,118 +1,37 @@ use std::{pin::Pin, task::Poll}; -use derive_more::From; use futures::{future::Future, task::Context}; -use thiserror::Error; -use tokio::sync::oneshot; use tower::Service; -use tracing::Span; use crate::{ - message::{ - message::MessageOwned, - payload::{update::UpdateOwned, PayloadOwned}, - }, - state_machine::requests::{ - Request, - RequestSender, - Sum2Request, - Sum2Response, - SumRequest, - SumResponse, - UpdateRequest, - UpdateResponse, - }, - utils::trace::{Traceable, Traced}, - PetError, + message::message::MessageOwned, + state_machine::{requests::RequestSender, StateMachineResult}, + utils::Request, }; -/// [`StateMachineService`] request type -#[derive(Debug, From)] -pub struct StateMachineRequest(MessageOwned); - -/// [`StateMachineService`] response type -pub type StateMachineResponse = Result<(), StateMachineError>; - -/// [`StateMachineService`] error type -#[derive(Debug, Error)] -pub enum StateMachineError { - #[error("PET protocol error: {0}")] - Pet(PetError), - - #[error("Unknown internal error")] - InternalError, -} +pub use crate::state_machine::{StateMachineError, StateMachineResult as StateMachineResponse}; /// A service that hands the requests to the state machine /// ([`crate::state_machine::StateMachine`]) that runs in the /// background. pub struct StateMachineService { - handle: RequestSender>, + handle: RequestSender, } impl StateMachineService { /// Create a new service with the given handle for forwarding /// requests to the state machine. The handle should be obtained /// via [`crate::state_machine::StateMachine::new`] - pub fn new(handle: RequestSender>) -> Self { + pub fn new(handle: RequestSender) -> Self { Self { handle } } - - fn handler(&self) -> StateMachineRequestHandler { - trace!("creating new handler"); - StateMachineRequestHandler { - handle: self.handle.clone(), - } - } -} - -struct StateMachineRequestHandler { - handle: RequestSender>, } -impl StateMachineRequestHandler { - fn send_request(&mut self, span: Span, req: Request) -> Result<(), StateMachineError> { - let req = Traced::new(req, span); - self.handle.send(req).map_err(|e| { - warn!("could not send request to the state machine: {:?}", e); - StateMachineError::InternalError - })?; - Ok(()) - } - - async fn sum_request(mut self, span: Span, req: SumRequest) -> StateMachineResponse { - let (resp_tx, resp_rx) = oneshot::channel::(); - self.send_request(span, Request::Sum((req, resp_tx)))?; - let sum_resp = resp_rx.await.map_err(|_| { - warn!("could not get response from state machine"); - StateMachineError::InternalError - })?; - sum_resp.map_err(StateMachineError::Pet) - } +/// Request type for [`StateMachineService`] +pub type StateMachineRequest = Request; - async fn update_request(mut self, span: Span, req: UpdateRequest) -> StateMachineResponse { - let (resp_tx, resp_rx) = oneshot::channel::(); - self.send_request(span, Request::Update((req, resp_tx)))?; - let update_resp = resp_rx.await.map_err(|_| { - warn!("could not get response from state machine"); - StateMachineError::InternalError - })?; - update_resp.map_err(StateMachineError::Pet) - } - - async fn sum2_request(mut self, span: Span, req: Sum2Request) -> StateMachineResponse { - let (resp_tx, resp_rx) = oneshot::channel::(); - self.send_request(span, Request::Sum2((req, resp_tx)))?; - let sum2_resp = resp_rx.await.map_err(|_| { - warn!("could not get response from state machine"); - StateMachineError::InternalError - })?; - sum2_resp.map_err(StateMachineError::Pet) - } -} - -impl Service> for StateMachineService { - type Response = StateMachineResponse; +impl Service for StateMachineService { + type Response = StateMachineResult; type Error = ::std::convert::Infallible; #[allow(clippy::type_complexity)] type Future = @@ -122,44 +41,8 @@ impl Service> for StateMachineService { Poll::Ready(Ok(())) } - fn call(&mut self, req: Traced) -> Self::Future { - trace!("creating a new handler for the request"); - let handler = self.handler(); - let req_span = req.span().clone(); - - let MessageOwned { header, payload } = req.into_inner().0; - - match payload { - PayloadOwned::Sum(sum) => { - debug!("creating a sum request to send to the state machine"); - let req = SumRequest { - participant_pk: header.participant_pk, - ephm_pk: sum.ephm_pk, - }; - Box::pin(async move { Ok(handler.sum_request(req_span, req).await) }) - } - PayloadOwned::Update(update) => { - debug!("creating an update request to send to the state machine"); - let UpdateOwned { - local_seed_dict, - masked_model, - .. - } = update; - let req = UpdateRequest { - participant_pk: header.participant_pk, - local_seed_dict, - masked_model, - }; - Box::pin(async move { Ok(handler.update_request(req_span, req).await) }) - } - PayloadOwned::Sum2(sum2) => { - debug!("creating a sum2 request to send to the state machine"); - let req = Sum2Request { - participant_pk: header.participant_pk, - mask: sum2.mask, - }; - Box::pin(async move { Ok(handler.sum2_request(req_span, req).await) }) - } - } + fn call(&mut self, req: StateMachineRequest) -> Self::Future { + let handle = self.handle.clone(); + Box::pin(async move { Ok(handle.request(req).await) }) } } diff --git a/rust/src/services/mod.rs b/rust/src/services/mod.rs index bc0d97033..3f2920a67 100644 --- a/rust/src/services/mod.rs +++ b/rust/src/services/mod.rs @@ -48,11 +48,7 @@ use crate::{ StateMachineService, }, }, - state_machine::{ - events::EventSubscriber, - requests::{Request, RequestSender}, - }, - utils::trace::Traced, + state_machine::{events::EventSubscriber, requests::RequestSender}, }; use std::sync::Arc; @@ -105,8 +101,8 @@ pub fn fetcher(event_subscriber: &EventSubscriber) -> impl Fetcher + Sync + Send /// Construct a [`PetMessageHandler`] service pub fn message_handler( event_subscriber: &EventSubscriber, - requests_tx: RequestSender>, -) -> impl PetMessageHandler + Sync + Send + 'static { + requests_tx: RequestSender, +) -> impl PetMessageHandler + Sync + Send + 'static + Clone { // TODO: make this configurable. Users should be able to // choose how many threads they want etc. // diff --git a/rust/src/services/tests/messages/message_parser.rs b/rust/src/services/tests/messages/message_parser.rs index 313361548..a5b687242 100644 --- a/rust/src/services/tests/messages/message_parser.rs +++ b/rust/src/services/tests/messages/message_parser.rs @@ -20,7 +20,7 @@ use crate::{ events::{EventPublisher, EventSubscriber}, phases::PhaseName, }, - utils::trace::Traced, + utils::Request, }; fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) { @@ -30,8 +30,8 @@ fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) (publisher, subscriber, task) } -fn make_req(bytes: Vec) -> Traced { - Traced::new(bytes.into(), error_span!("test")) +fn make_req(bytes: Vec) -> MessageParserRequest> { + Request::new(bytes.into()) } fn new_sum_message(round_params: &RoundParameters) -> (MessageOwned, Vec) { @@ -41,10 +41,14 @@ fn new_sum_message(round_params: &RoundParameters) -> (MessageOwned, Vec) { (message, encrypted_message) } +fn assert_ready(task: &mut Spawn) { + assert_ready!(task.poll_ready::>>()).unwrap(); +} + #[tokio::test] async fn test_decrypt_fail() { let (_publisher, _subscriber, mut task) = spawn_svc(); - assert_ready!(task.poll_ready()).unwrap(); + assert_ready(&mut task); let req = make_req(vec![0, 1, 2, 3, 4, 5, 6]); let resp: Result = task.call(req).await; @@ -54,13 +58,13 @@ async fn test_decrypt_fail() { Ok(Err(MessageParserError::Decrypt)) => {} _ => panic!("expected decrypt error"), } - assert_ready!(task.poll_ready()).unwrap(); + assert_ready(&mut task); } #[tokio::test] async fn test_valid_request() { let (mut publisher, subscriber, mut task) = spawn_svc(); - assert_ready!(task.poll_ready()).unwrap(); + assert_ready(&mut task); let round_params = subscriber.params_listener().get_latest().event; let (message, encrypted_message) = new_sum_message(&round_params); @@ -78,7 +82,7 @@ async fn test_valid_request() { #[tokio::test] async fn test_unexpected_message() { let (_publisher, subscriber, mut task) = spawn_svc(); - assert_ready!(task.poll_ready()).unwrap(); + assert_ready(&mut task); let round_params = subscriber.params_listener().get_latest().event; let (_, encrypted_message) = new_sum_message(&round_params); diff --git a/rust/src/services/tests/messages/pre_processor.rs b/rust/src/services/tests/messages/pre_processor.rs index e55452b80..be1c18520 100644 --- a/rust/src/services/tests/messages/pre_processor.rs +++ b/rust/src/services/tests/messages/pre_processor.rs @@ -11,7 +11,7 @@ use crate::{ events::{EventPublisher, EventSubscriber}, phases::PhaseName, }, - utils::trace::Traced, + utils::Request, }; fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) { @@ -20,8 +20,8 @@ fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) (publisher, subscriber, task) } -fn make_req(message: MessageOwned) -> Traced { - Traced::new(message.into(), error_span!("test")) +fn make_req(message: MessageOwned) -> PreProcessorRequest { + Request::new(message) } #[tokio::test] diff --git a/rust/src/state_machine/mod.rs b/rust/src/state_machine/mod.rs index 0d585efaf..7e327d5a5 100644 --- a/rust/src/state_machine/mod.rs +++ b/rust/src/state_machine/mod.rs @@ -117,16 +117,27 @@ use crate::{ state_machine::{ coordinator::CoordinatorState, events::EventSubscriber, - phases::{Idle, Phase, PhaseState, Purge, Shutdown, StateError, Sum, Sum2, Unmask, Update}, - requests::{Request, RequestReceiver, RequestSender}, + phases::{Idle, Phase, PhaseState, Shutdown, StateError, Sum, Sum2, Unmask, Update}, + requests::{RequestReceiver, RequestSender}, }, - utils::trace::Traced, InitError, + PetError, }; use derive_more::From; use thiserror::Error; +/// Error returned when the state machine fails to handle a request +#[derive(Debug, Error)] +pub enum StateMachineError { + #[error("the request failed")] + RequestFailed(#[from] PetError), + #[error("the request could not be processed due to an internal error")] + InternalError, +} + +pub type StateMachineResult = Result<(), StateMachineError>; + /// Error that occurs when unmasking of the global model fails. #[derive(Error, Debug, Eq, PartialEq)] pub enum RoundFailed { @@ -140,28 +151,25 @@ pub enum RoundFailed { /// The state machine with all its states. #[derive(From)] -pub enum StateMachine { - Idle(PhaseState), - Sum(PhaseState), - Update(PhaseState), - Sum2(PhaseState), - Unmask(PhaseState), - Error(PhaseState), - Shutdown(PhaseState), +pub enum StateMachine { + Idle(PhaseState), + Sum(PhaseState), + Update(PhaseState), + Sum2(PhaseState), + Unmask(PhaseState), + Error(PhaseState), + Shutdown(PhaseState), } -/// A [`StateMachine`] that processes `Traced`. -pub type TracingStateMachine = StateMachine>; - -impl StateMachine +impl StateMachine where - PhaseState: Phase + Purge, - PhaseState: Phase + Purge, - PhaseState: Phase + Purge, - PhaseState: Phase + Purge, - PhaseState: Phase + Purge, - PhaseState: Phase + Purge, - PhaseState: Phase + Purge, + PhaseState: Phase, + PhaseState: Phase, + PhaseState: Phase, + PhaseState: Phase, + PhaseState: Phase, + PhaseState: Phase, + PhaseState: Phase, { /// Creates a new state machine with the initial state [`Idle`]. /// @@ -175,7 +183,7 @@ where ///
///
     ///     Note: If the StateMachine is created via
-    ///     PhaseState::::new(...) it must be ensured that the module
+    ///     PhaseState::::new(...) it must be ensured that the module
     ///     
     ///     sodiumoxide::init() has been initialized beforehand.
     /// 
@@ -184,21 +192,21 @@ where /// ```compile_fail /// sodiumoxide::init().unwrap(); /// let state_machine = - /// StateMachine::from(PhaseState::::new(coordinator_state, req_receiver)); + /// StateMachine::from(PhaseState::::new(coordinator_state, req_receiver)); /// ``` pub fn new( pet_settings: PetSettings, mask_settings: MaskSettings, model_settings: ModelSettings, - ) -> Result<(Self, RequestSender, EventSubscriber), InitError> { + ) -> Result<(Self, RequestSender, EventSubscriber), InitError> { // crucial: init must be called before anything else in this module sodiumoxide::init().or(Err(InitError))?; let (coordinator_state, event_subscriber) = CoordinatorState::new(pet_settings, mask_settings, model_settings); - let (req_receiver, handle) = RequestReceiver::::new(); + let (req_receiver, handle) = RequestReceiver::new(); let state_machine = - StateMachine::from(PhaseState::::new(coordinator_state, req_receiver)); + StateMachine::from(PhaseState::::new(coordinator_state, req_receiver)); Ok((state_machine, handle, event_subscriber)) } diff --git a/rust/src/state_machine/phases/error.rs b/rust/src/state_machine/phases/error.rs index cb293aee9..ed2019efb 100644 --- a/rust/src/state_machine/phases/error.rs +++ b/rust/src/state_machine/phases/error.rs @@ -18,11 +18,11 @@ pub enum StateError { TimeoutError(#[from] tokio::time::Elapsed), } -impl PhaseState { +impl PhaseState { /// Creates a new error state. pub fn new( coordinator_state: CoordinatorState, - request_rx: RequestReceiver, + request_rx: RequestReceiver, error: StateError, ) -> Self { info!("state transition"); @@ -35,10 +35,7 @@ impl PhaseState { } #[async_trait] -impl Phase for PhaseState -where - R: Send, -{ +impl Phase for PhaseState { const NAME: PhaseName = PhaseName::Error; async fn run(&mut self) -> Result<(), StateError> { @@ -55,12 +52,12 @@ where /// Moves from the error state to the next state. /// /// See the [module level documentation](../index.html) for more details. - fn next(self) -> Option> { + fn next(self) -> Option { Some(match self.inner { StateError::ChannelError(_) => { - PhaseState::::new(self.coordinator_state, self.request_rx).into() + PhaseState::::new(self.coordinator_state, self.request_rx).into() } - _ => PhaseState::::new(self.coordinator_state, self.request_rx).into(), + _ => PhaseState::::new(self.coordinator_state, self.request_rx).into(), }) } } diff --git a/rust/src/state_machine/phases/idle.rs b/rust/src/state_machine/phases/idle.rs index cfaac6f88..835c3478c 100644 --- a/rust/src/state_machine/phases/idle.rs +++ b/rust/src/state_machine/phases/idle.rs @@ -3,11 +3,12 @@ use crate::{ state_machine::{ coordinator::{CoordinatorState, RoundSeed}, events::{DictionaryUpdate, MaskLengthUpdate, ScalarUpdate}, - phases::{reject_request, Handler, Phase, PhaseName, PhaseState, Sum}, - requests::{Request, RequestReceiver}, + phases::{Handler, Phase, PhaseName, PhaseState, Sum}, + requests::{RequestReceiver, StateMachineRequest}, StateError, StateMachine, }, + PetError, }; use sodiumoxide::crypto::hash::sha256; @@ -16,18 +17,15 @@ use sodiumoxide::crypto::hash::sha256; #[derive(Debug)] pub struct Idle; -impl Handler for PhaseState { - /// Reject all the request with a [`PetError::InvalidMessage`] - fn handle_request(&mut self, req: Request) { - reject_request(req); +impl Handler for PhaseState { + /// Reject the request with a [`PetError::InvalidMessage`] + fn handle_request(&mut self, _req: StateMachineRequest) -> Result<(), PetError> { + Err(PetError::InvalidMessage) } } #[async_trait] -impl Phase for PhaseState -where - R: Send, -{ +impl Phase for PhaseState { const NAME: PhaseName = PhaseName::Idle; /// Moves from the idle state to the next state. @@ -67,15 +65,15 @@ where Ok(()) } - fn next(self) -> Option> { + fn next(self) -> Option { info!("going to sum phase"); - Some(PhaseState::::new(self.coordinator_state, self.request_rx).into()) + Some(PhaseState::::new(self.coordinator_state, self.request_rx).into()) } } -impl PhaseState { +impl PhaseState { /// Creates a new idle state. - pub fn new(mut coordinator_state: CoordinatorState, request_rx: RequestReceiver) -> Self { + pub fn new(mut coordinator_state: CoordinatorState, request_rx: RequestReceiver) -> Self { // Since some events are emitted very early, the round id must // be correct when the idle phase starts. Therefore, we update // it here, when instantiating the idle PhaseState. @@ -135,8 +133,8 @@ mod test { let id = keys.get_latest().round_id; assert_eq!(id, 0); - let (request_rx, _request_tx) = RequestReceiver::::new(); - let mut idle_phase = PhaseState::::new(coordinator_state, request_rx); + let (request_rx, _request_tx) = RequestReceiver::new(); + let mut idle_phase = PhaseState::::new(coordinator_state, request_rx); idle_phase.run().await.unwrap(); let id = keys.get_latest().round_id; diff --git a/rust/src/state_machine/phases/mod.rs b/rust/src/state_machine/phases/mod.rs index 993292cf6..2c88d1ba4 100644 --- a/rust/src/state_machine/phases/mod.rs +++ b/rust/src/state_machine/phases/mod.rs @@ -21,15 +21,14 @@ pub use self::{ use crate::{ state_machine::{ coordinator::CoordinatorState, - requests::{Request, RequestReceiver}, + requests::{RequestReceiver, ResponseSender, StateMachineRequest}, StateMachine, }, - utils::trace::{Traceable, Traced}, + utils::Request, PetError, }; use futures::StreamExt; -use tokio::sync::oneshot; use tracing_futures::Instrument; /// Name of the current phase @@ -46,7 +45,7 @@ pub enum PhaseName { /// A trait that must be implemented by a state in order to move to a next state. #[async_trait] -pub trait Phase { +pub trait Phase { /// Name of the current phase const NAME: PhaseName; @@ -54,68 +53,31 @@ pub trait Phase { async fn run(&mut self) -> Result<(), StateError>; /// Moves from this state to the next state. - fn next(self) -> Option>; + fn next(self) -> Option; } /// A trait that must be implemented by a state to handle a request. -pub trait Handler { +pub trait Handler { /// Handles a request. - fn handle_request(&mut self, req: R); -} - -/// When the state machine transitions to a new phase, all the pending -/// requests are considered outdated, and purged. The [`Purge`] trait -/// implements this behavior. -pub trait Purge { - /// Process an outdated request. - fn handle_outdated_request(&mut self, req: R); -} - -impl Purge for PhaseState { - fn handle_outdated_request(&mut self, req: Request) { - reject_request(req) - } -} - -impl Purge> for PhaseState -where - Self: Purge, -{ - fn handle_outdated_request(&mut self, req: Traced) { - let span = req.span().clone(); - let _enter = span.enter(); - >::handle_outdated_request(self, req.into_inner()) - } -} - -impl Handler> for PhaseState -where - Self: Handler, -{ - /// Handles a [`Request`]. - fn handle_request(&mut self, req: Traced) { - let span = req.span().clone(); - let _enter = span.enter(); - >::handle_request(self, req.into_inner()) - } + fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), PetError>; } /// The state corresponding to a phase of the PET protocol. /// /// This contains the state-dependent `inner` state and the state-independent `coordinator_state` /// which is shared across state transitions. -pub struct PhaseState { +pub struct PhaseState { /// The inner state. pub(in crate::state_machine) inner: S, /// The Coordinator state. pub(in crate::state_machine) coordinator_state: CoordinatorState, /// The request receiver half. - pub(in crate::state_machine) request_rx: RequestReceiver, + pub(in crate::state_machine) request_rx: RequestReceiver, } -impl PhaseState +impl PhaseState where - Self: Handler + Phase + Purge, + Self: Handler + Phase, { /// Processes requests for as long as the given duration. async fn process_during(&mut self, dur: tokio::time::Duration) -> Result<(), StateError> { @@ -140,20 +102,25 @@ where /// Processes the next available request. async fn process_single(&mut self) -> Result<(), StateError> { - let req = self.next_request().await?; - self.handle_request(req); + let (req, resp_tx) = self.next_request().await?; + let span = req.span(); + let _span_guard = span.enter(); + let res = self.handle_request(req.into_inner()); + // This may error out if the receiver has already be dropped but + // it doesn't matter for us. + let _ = resp_tx.send(res.map_err(Into::into)); Ok(()) } } -impl PhaseState +impl PhaseState where - Self: Phase + Purge, + Self: Phase, { /// Run the current phase to completion, then transition to the /// next phase and return it. - pub async fn run_phase(mut self) -> Option> { - let phase = >::NAME; + pub async fn run_phase(mut self) -> Option { + let phase = ::NAME; let span = error_span!("run_phase", phase = ?phase); async move { @@ -194,7 +161,12 @@ where fn purge_outdated_requests(&mut self) -> Result<(), StateError> { loop { match self.try_next_request()? { - Some(req) => self.handle_outdated_request(req), + Some((req, resp_tx)) => { + let span = req.span(); + let _span_guard = span.enter(); + info!("rejecting request"); + let _ = resp_tx.send(Err(PetError::InvalidMessage.into())); + } None => return Ok(()), } } @@ -202,12 +174,14 @@ where } // Functions that are available to all states -impl PhaseState { +impl PhaseState { /// Receives the next [`Request`]. /// /// # Errors /// Returns [`StateError::ChannelError`] when all sender halves have been dropped. - async fn next_request(&mut self) -> Result { + async fn next_request( + &mut self, + ) -> Result<(Request, ResponseSender), StateError> { debug!("waiting for the next incoming request"); self.request_rx.next().await.ok_or_else(|| { error!("request receiver broken: senders have been dropped"); @@ -215,9 +189,11 @@ impl PhaseState { }) } - fn try_next_request(&mut self) -> Result, StateError> { + fn try_next_request( + &mut self, + ) -> Result, ResponseSender)>, StateError> { match self.request_rx.try_recv() { - Ok(req) => Ok(Some(req)), + Ok(item) => Ok(Some(item)), Err(tokio::sync::mpsc::error::TryRecvError::Empty) => { debug!("no pending request"); Ok(None) @@ -231,26 +207,7 @@ impl PhaseState { } } - fn into_error_state(self, err: StateError) -> StateMachine { - PhaseState::::new(self.coordinator_state, self.request_rx, err).into() + fn into_error_state(self, err: StateError) -> StateMachine { + PhaseState::::new(self.coordinator_state, self.request_rx, err).into() } } - -/// Respond to the given request with a rejection error. -pub fn reject_request(req: Request) { - match req { - Request::Sum((_, response_tx)) => send_rejection(response_tx), - Request::Update((_, response_tx)) => send_rejection(response_tx), - Request::Sum2((_, response_tx)) => send_rejection(response_tx), - } -} - -/// Send a rejection through the given channel -fn send_rejection(response_tx: oneshot::Sender>) { - debug!("invalid message"); - // `send` returns an error if the receiver half has already - // been dropped. This means that the receiver is not - // interested in the response of the request. Therefore the - // error is ignored. - let _ = response_tx.send(Err(PetError::InvalidMessage)); -} diff --git a/rust/src/state_machine/phases/shutdown.rs b/rust/src/state_machine/phases/shutdown.rs index 7b50e8762..d56d94bfb 100644 --- a/rust/src/state_machine/phases/shutdown.rs +++ b/rust/src/state_machine/phases/shutdown.rs @@ -11,10 +11,7 @@ use crate::state_machine::{ pub struct Shutdown; #[async_trait] -impl Phase for PhaseState -where - R: Send, -{ +impl Phase for PhaseState { const NAME: PhaseName = PhaseName::Shutdown; /// Shuts down the [`StateMachine`]. @@ -27,14 +24,14 @@ where Ok(()) } - fn next(self) -> Option> { + fn next(self) -> Option { None } } -impl PhaseState { +impl PhaseState { /// Creates a new shutdown state. - pub fn new(coordinator_state: CoordinatorState, request_rx: RequestReceiver) -> Self { + pub fn new(coordinator_state: CoordinatorState, request_rx: RequestReceiver) -> Self { info!("state transition"); Self { inner: Shutdown, diff --git a/rust/src/state_machine/phases/sum.rs b/rust/src/state_machine/phases/sum.rs index 4af9896f9..1520762ce 100644 --- a/rust/src/state_machine/phases/sum.rs +++ b/rust/src/state_machine/phases/sum.rs @@ -4,28 +4,17 @@ use crate::{ state_machine::{ coordinator::CoordinatorState, events::DictionaryUpdate, - phases::{ - reject_request, - Handler, - Phase, - PhaseName, - PhaseState, - Purge, - StateError, - Update, - }, - requests::{Request, RequestReceiver, SumRequest, SumResponse}, + phases::{Handler, Phase, PhaseName, PhaseState, StateError, Update}, + requests::{RequestReceiver, StateMachineRequest, SumRequest}, StateMachine, }, LocalSeedDict, + PetError, SeedDict, SumDict, }; -use tokio::{ - sync::oneshot, - time::{timeout, Duration}, -}; +use tokio::time::{timeout, Duration}; /// Sum state #[derive(Debug)] @@ -43,26 +32,24 @@ impl Sum { } } -impl Handler for PhaseState { - /// Handles a [`Request::Sum`], [`Request::Update`] or [`Request::Sum2`] request.\ - /// - /// If the request is a [`Request::Update`] or [`Request::Sum2`] request, the request sender - /// will receive a [`PetError::InvalidMessage`]. +impl Handler for PhaseState { + /// Handles a [`StateMachineRequest`]. /// - /// [`PetError::InvalidMessage`]: crate::PetError::InvalidMessage - fn handle_request(&mut self, req: Request) { + /// If the request is a [`StateMachineRequest::Update`] or + /// [`StateMachineRequest::Sum2`] request, the request sender will receive a + /// [`PetError::InvalidMessage`]. + fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), PetError> { match req { - Request::Sum((sum_req, response_tx)) => self.handle_sum(sum_req, response_tx), - _ => reject_request(req), + StateMachineRequest::Sum(sum_req) => self.handle_sum(sum_req), + _ => Err(PetError::InvalidMessage), } } } #[async_trait] -impl Phase for PhaseState +impl Phase for PhaseState where - Self: Handler + Purge, - R: Send, + Self: Handler, { const NAME: PhaseName = PhaseName::Sum; @@ -86,7 +73,7 @@ where Ok(()) } - fn next(self) -> Option> { + fn next(self) -> Option { let Self { inner: Sum { sum_dict, @@ -96,7 +83,7 @@ where request_rx, } = self; Some( - PhaseState::::new( + PhaseState::::new( coordinator_state, request_rx, sum_dict, @@ -110,9 +97,9 @@ where } } -impl PhaseState +impl PhaseState where - Self: Handler + Phase + Purge, + Self: Handler + Phase, { /// Processes requests until there are enough. async fn process_until_enough(&mut self) -> Result<(), StateError> { @@ -128,9 +115,9 @@ where } } -impl PhaseState { +impl PhaseState { /// Creates a new sum state. - pub fn new(coordinator_state: CoordinatorState, request_rx: RequestReceiver) -> Self { + pub fn new(coordinator_state: CoordinatorState, request_rx: RequestReceiver) -> Self { info!("state transition"); Self { inner: Sum { @@ -143,14 +130,13 @@ impl PhaseState { } /// Handles a sum request. - fn handle_sum(&mut self, req: SumRequest, response_tx: oneshot::Sender) { + fn handle_sum(&mut self, req: SumRequest) -> Result<(), PetError> { let SumRequest { participant_pk, ephm_pk, } = req; - self.inner.sum_dict.insert(participant_pk, ephm_pk); - let _ = response_tx.send(Ok(())); + Ok(()) } /// Freezes the sum dictionary. @@ -194,7 +180,7 @@ mod test { sum_dict: SumDict::new(), seed_dict: None, }; - let (state_machine, mut request_tx, events) = StateMachineBuilder::new() + let (state_machine, request_tx, events) = StateMachineBuilder::new() .with_phase(sum) // Make sure anyone is a sum participant. .with_sum_ratio(1.0) @@ -216,7 +202,7 @@ mod test { // update phase let mut summer = generate_summer(&seed, 1.0, 0.0); let sum_msg = summer.compose_sum_message(&keys.public); - let request_fut = async { request_tx.sum(&sum_msg).await.unwrap() }; + let request_fut = async { request_tx.msg(&sum_msg).await.unwrap() }; let transition_fut = async { state_machine.next().await.unwrap() }; let (_response, state_machine) = tokio::join!(request_fut, transition_fut); diff --git a/rust/src/state_machine/phases/sum2.rs b/rust/src/state_machine/phases/sum2.rs index 334d02e7a..df9cf3811 100644 --- a/rust/src/state_machine/phases/sum2.rs +++ b/rust/src/state_machine/phases/sum2.rs @@ -2,17 +2,8 @@ use crate::{ mask::{masking::Aggregation, object::MaskObject}, state_machine::{ coordinator::{CoordinatorState, MaskDict}, - phases::{ - reject_request, - Handler, - Phase, - PhaseName, - PhaseState, - Purge, - StateError, - Unmask, - }, - requests::{Request, RequestReceiver, Sum2Request, Sum2Response}, + phases::{Handler, Phase, PhaseName, PhaseState, StateError, Unmask}, + requests::{RequestReceiver, StateMachineRequest, Sum2Request}, StateMachine, }, PetError, @@ -20,10 +11,7 @@ use crate::{ SumParticipantPublicKey, }; -use tokio::{ - sync::oneshot, - time::{timeout, Duration}, -}; +use tokio::time::{timeout, Duration}; /// Sum2 state #[derive(Debug)] @@ -54,10 +42,9 @@ impl Sum2 { } #[async_trait] -impl Phase for PhaseState +impl Phase for PhaseState where - Self: Purge + Handler, - R: Send, + Self: Handler, { const NAME: PhaseName = PhaseName::Sum2; @@ -83,9 +70,9 @@ where /// Moves from the sum2 state to the next state. /// /// See the [module level documentation](../index.html) for more details. - fn next(self) -> Option> { + fn next(self) -> Option { Some( - PhaseState::::new( + PhaseState::::new( self.coordinator_state, self.request_rx, self.inner.aggregation, @@ -96,9 +83,9 @@ where } } -impl PhaseState +impl PhaseState where - Self: Handler + Phase + Purge, + Self: Handler + Phase, { /// Processes requests until there are enough. async fn process_until_enough(&mut self) -> Result<(), StateError> { @@ -114,24 +101,25 @@ where } } -impl Handler for PhaseState { - /// Handles a [`Request::Sum`], [`Request::Update`] or [`Request::Sum2`] request. +impl Handler for PhaseState { + /// Handles a [`StateMachineRequest`], /// - /// If the request is a [`Request::Sum`] or [`Request::Update`] request, the request sender + /// If the request is a [`StateMachineRequest::Sum`] or + /// [`StateMachineRequest::Update`] request, the request sender /// will receive a [`PetError::InvalidMessage`]. - fn handle_request(&mut self, req: Request) { + fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), PetError> { match req { - Request::Sum2((sum2_req, response_tx)) => self.handle_sum2(sum2_req, response_tx), - _ => reject_request(req), + StateMachineRequest::Sum2(sum2_req) => self.handle_sum2(sum2_req), + _ => Err(PetError::InvalidMessage), } } } -impl PhaseState { +impl PhaseState { /// Creates a new sum2 state. pub fn new( coordinator_state: CoordinatorState, - request_rx: RequestReceiver, + request_rx: RequestReceiver, sum_dict: SumDict, aggregation: Aggregation, ) -> Self { @@ -149,14 +137,12 @@ impl PhaseState { /// Handles a sum2 request. /// If the handling of the sum2 message fails, an error is returned to the request sender. - fn handle_sum2(&mut self, req: Sum2Request, response_tx: oneshot::Sender) { + fn handle_sum2(&mut self, req: Sum2Request) -> Result<(), PetError> { let Sum2Request { participant_pk, mask, } = req; - - // See `Self::handle_invalid_message` - let _ = response_tx.send(self.add_mask(&participant_pk, mask)); + self.add_mask(&participant_pk, mask) } /// Adds a mask to the mask dictionary. @@ -191,6 +177,7 @@ impl PhaseState { #[cfg(test)] mod test { + use super::*; use crate::{ crypto::{ByteObject, EncryptKeyPair}, @@ -258,7 +245,7 @@ mod test { .unwrap(); // Have the state machine process the request - let req = async { request_tx.clone().sum2(&msg).await.unwrap() }; + let req = async { request_tx.msg(&msg).await.unwrap() }; let transition = async { state_machine.next().await.unwrap() }; let ((), state_machine) = tokio::join!(req, transition); assert!(state_machine.is_unmask()); diff --git a/rust/src/state_machine/phases/unmask.rs b/rust/src/state_machine/phases/unmask.rs index dd4ecc057..5268489b1 100644 --- a/rust/src/state_machine/phases/unmask.rs +++ b/rust/src/state_machine/phases/unmask.rs @@ -33,10 +33,7 @@ impl Unmask { } #[async_trait] -impl Phase for PhaseState -where - R: Send, -{ +impl Phase for PhaseState { const NAME: PhaseName = PhaseName::Unmask; /// Run the unmasking phase @@ -54,17 +51,17 @@ where /// Moves from the unmask state to the next state. /// /// See the [module level documentation](../index.html) for more details. - fn next(self) -> Option> { + fn next(self) -> Option { info!("going back to idle phase"); - Some(PhaseState::::new(self.coordinator_state, self.request_rx).into()) + Some(PhaseState::::new(self.coordinator_state, self.request_rx).into()) } } -impl PhaseState { +impl PhaseState { /// Creates a new unmask state. pub fn new( coordinator_state: CoordinatorState, - request_rx: RequestReceiver, + request_rx: RequestReceiver, aggregation: Aggregation, mask_dict: MaskDict, ) -> Self { diff --git a/rust/src/state_machine/phases/update.rs b/rust/src/state_machine/phases/update.rs index 1c531e197..acac7bd1f 100644 --- a/rust/src/state_machine/phases/update.rs +++ b/rust/src/state_machine/phases/update.rs @@ -5,8 +5,8 @@ use crate::{ state_machine::{ coordinator::CoordinatorState, events::{DictionaryUpdate, MaskLengthUpdate, ScalarUpdate}, - phases::{reject_request, Handler, Phase, PhaseName, PhaseState, Purge, StateError, Sum2}, - requests::{Request, RequestReceiver, UpdateRequest, UpdateResponse}, + phases::{Handler, Phase, PhaseName, PhaseState, StateError, Sum2}, + requests::{RequestReceiver, StateMachineRequest, UpdateRequest}, StateMachine, }, LocalSeedDict, @@ -16,10 +16,7 @@ use crate::{ UpdateParticipantPublicKey, }; -use tokio::{ - sync::oneshot, - time::{timeout, Duration}, -}; +use tokio::time::{timeout, Duration}; /// Update state #[derive(Debug)] @@ -48,10 +45,9 @@ impl Update { } #[async_trait] -impl Phase for PhaseState +impl Phase for PhaseState where - Self: Handler + Purge, - R: Send, + Self: Handler, { const NAME: PhaseName = PhaseName::Update; @@ -82,7 +78,7 @@ where Ok(()) } - fn next(self) -> Option> { + fn next(self) -> Option { let PhaseState { inner: Update { @@ -105,15 +101,15 @@ where .broadcast_seed_dict(DictionaryUpdate::New(Arc::new(seed_dict))); Some( - PhaseState::::new(coordinator_state, request_rx, frozen_sum_dict, aggregation) + PhaseState::::new(coordinator_state, request_rx, frozen_sum_dict, aggregation) .into(), ) } } -impl PhaseState +impl PhaseState where - Self: Handler + Phase + Purge, + Self: Handler + Phase, { /// Processes requests until there are enough. async fn process_until_enough(&mut self) -> Result<(), StateError> { @@ -129,26 +125,25 @@ where } } -impl Handler for PhaseState { - /// Handles a [`Request::Sum`], [`Request::Update`] or [`Request::Sum2`] request. +impl Handler for PhaseState { + /// Handles a [`StateMachineRequest`]. /// - /// If the request is a [`Request::Sum`] or [`Request::Sum2`] request, the request sender - /// will receive a [`PetError::InvalidMessage`]. - fn handle_request(&mut self, req: Request) { + /// If the request is a [`StateMachineRequest::Sum`] or + /// [`StateMachineRequest::Sum2`] request, the request sender will + /// receive a [`PetError::InvalidMessage`]. + fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), PetError> { match req { - Request::Update((update_req, response_tx)) => { - self.handle_update(update_req, response_tx) - } - _ => reject_request(req), + StateMachineRequest::Update(update_req) => self.handle_update(update_req), + _ => Err(PetError::InvalidMessage), } } } -impl PhaseState { +impl PhaseState { /// Creates a new update state. pub fn new( coordinator_state: CoordinatorState, - request_rx: RequestReceiver, + request_rx: RequestReceiver, frozen_sum_dict: SumDict, seed_dict: SeedDict, ) -> Self { @@ -169,19 +164,13 @@ impl PhaseState { /// Handles an update request. /// If the handling of the update message fails, an error is returned to the request sender. - fn handle_update(&mut self, req: UpdateRequest, response_tx: oneshot::Sender) { + fn handle_update(&mut self, req: UpdateRequest) -> Result<(), PetError> { let UpdateRequest { participant_pk, local_seed_dict, masked_model, } = req; - - // See `handle_invalid_message` - let _ = response_tx.send(self.update_seed_dict_and_aggregate_mask( - &participant_pk, - &local_seed_dict, - masked_model, - )); + self.update_seed_dict_and_aggregate_mask(&participant_pk, &local_seed_dict, masked_model) } /// Updates the local seed dict and aggregates the masked model. @@ -320,7 +309,7 @@ mod test { }; // Create the state machine - let (state_machine, mut request_tx, events) = StateMachineBuilder::new() + let (state_machine, request_tx, events) = StateMachineBuilder::new() .with_seed(seed.clone()) .with_phase(update) .with_sum_ratio(sum_ratio) @@ -343,7 +332,7 @@ mod test { model.clone(), ); let masked_model = update_msg.masked_model(); - let request_fut = async { request_tx.update(&update_msg).await.unwrap() }; + let request_fut = async { request_tx.msg(&update_msg).await.unwrap() }; // Have the state machine process the request let transition_fut = async { state_machine.next().await.unwrap() }; diff --git a/rust/src/state_machine/requests.rs b/rust/src/state_machine/requests.rs index 7291047c0..988ba7f28 100644 --- a/rust/src/state_machine/requests.rs +++ b/rust/src/state_machine/requests.rs @@ -11,6 +11,7 @@ use derive_more::From; use futures::Stream; use thiserror::Error; use tokio::sync::{mpsc, oneshot}; +use tracing::Span; /// Error that occurs when a [`RequestSender`] tries to send a request on a closed `Request` channel. #[derive(Debug, Error)] @@ -19,15 +20,18 @@ pub struct StateMachineShutdown; use crate::{ mask::object::MaskObject, + message::{MessageOwned, PayloadOwned, UpdateOwned}, + state_machine::{StateMachineError, StateMachineResult}, + utils::{Request, Traceable}, LocalSeedDict, ParticipantPublicKey, - PetError as Error, SumParticipantEphemeralPublicKey, SumParticipantPublicKey, UpdateParticipantPublicKey, }; /// A sum request. +#[derive(Debug)] pub struct SumRequest { /// The public key of the participant. pub participant_pk: SumParticipantPublicKey, @@ -36,6 +40,7 @@ pub struct SumRequest { } /// An update request. +#[derive(Debug)] pub struct UpdateRequest { /// The public key of the participant. pub participant_pk: UpdateParticipantPublicKey, @@ -46,6 +51,7 @@ pub struct UpdateRequest { } /// A sum2 request. +#[derive(Debug)] pub struct Sum2Request { /// The public key of the participant. pub participant_pk: ParticipantPublicKey, @@ -53,36 +59,62 @@ pub struct Sum2Request { pub mask: MaskObject, } -/// A sum response. -pub type SumResponse = Result<(), Error>; -/// An update response. -pub type UpdateResponse = Result<(), Error>; -/// A sum2 response. -pub type Sum2Response = Result<(), Error>; - /// A [`StateMachine`] request. /// /// [`StateMachine`]: crate::state_machine -pub enum Request { - Sum((SumRequest, oneshot::Sender)), - Update((UpdateRequest, oneshot::Sender)), - Sum2((Sum2Request, oneshot::Sender)), +#[derive(Debug, From)] +pub enum StateMachineRequest { + Sum(SumRequest), + Update(UpdateRequest), + Sum2(Sum2Request), } -/// A handle to send requests to the [`StateMachine`]. -/// -/// [`StateMachine`]: crate::state_machine -#[derive(From)] -pub struct RequestSender(mpsc::UnboundedSender); +impl Traceable for StateMachineRequest { + fn make_span(&self) -> Span { + let request_type = match self { + Self::Sum(_) => "sum", + Self::Update(_) => "update", + Self::Sum2(_) => "sum2", + }; + error_span!("StateMachineRequest", request_type = request_type) + } +} -impl Clone for RequestSender { - // Clones the sender half of the `Request` channel. - fn clone(&self) -> Self { - RequestSender(self.0.clone()) +impl From for StateMachineRequest { + fn from(message: MessageOwned) -> Self { + let MessageOwned { header, payload } = message; + match payload { + PayloadOwned::Sum(sum) => StateMachineRequest::Sum(SumRequest { + participant_pk: header.participant_pk, + ephm_pk: sum.ephm_pk, + }), + PayloadOwned::Update(update) => { + let UpdateOwned { + local_seed_dict, + masked_model, + .. + } = update; + StateMachineRequest::Update(UpdateRequest { + participant_pk: header.participant_pk, + local_seed_dict, + masked_model, + }) + } + PayloadOwned::Sum2(sum2) => StateMachineRequest::Sum2(Sum2Request { + participant_pk: header.participant_pk, + mask: sum2.mask, + }), + } } } -impl RequestSender { +/// A handle to send requests to the [`StateMachine`]. +/// +/// [`StateMachine`]: crate::state_machine +#[derive(Clone, From)] +pub struct RequestSender(mpsc::UnboundedSender<(Request, ResponseSender)>); + +impl RequestSender { /// Sends a request to the [`StateMachine`]. /// /// # Errors @@ -90,20 +122,37 @@ impl RequestSender { /// closed as a result. /// /// [`StateMachine`]: crate::state_machine - pub fn send(&self, req: R) -> Result<(), StateMachineShutdown> { - self.0.send(req).map_err(|_| StateMachineShutdown) + pub async fn request + Traceable>( + &self, + req: Request, + ) -> StateMachineResult { + let (resp_tx, resp_rx) = oneshot::channel::(); + self.0.send((req.map(Into::into), resp_tx)).map_err(|_| { + warn!("failed to send request to the state machine: state machine is shutting down"); + StateMachineError::InternalError + })?; + resp_rx.await.map_err(|_| { + warn!( + "failed to receive response from the state machine: state machine is shutting down" + ); + StateMachineError::InternalError + })? } } +/// A channel for sending the state machine to send the response to a +/// [`StateMachineRequest`]. +pub(in crate::state_machine) type ResponseSender = oneshot::Sender; + /// The receiver half of the `Request` channel that is used by the [`StateMachine`] to receive /// requests. /// /// [`StateMachine`]: crate::state_machine #[derive(From)] -pub struct RequestReceiver(mpsc::UnboundedReceiver); +pub struct RequestReceiver(mpsc::UnboundedReceiver<(Request, ResponseSender)>); -impl Stream for RequestReceiver { - type Item = R; +impl Stream for RequestReceiver { + type Item = (Request, ResponseSender); fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { trace!("RequestReceiver: polling"); @@ -111,11 +160,11 @@ impl Stream for RequestReceiver { } } -impl RequestReceiver { +impl RequestReceiver { /// Creates a new `Request` channel and returns the [`RequestReceiver`] as well as the /// [`RequestSender`] half. - pub fn new() -> (Self, RequestSender) { - let (tx, rx) = mpsc::unbounded_channel::(); + pub fn new() -> (Self, RequestSender) { + let (tx, rx) = mpsc::unbounded_channel::<(Request, ResponseSender)>(); let receiver = RequestReceiver::from(rx); let handle = RequestSender::from(tx); (receiver, handle) @@ -133,7 +182,7 @@ impl RequestReceiver { /// See [the `tokio` documentation][receive] for more information. /// /// [receive]: https://docs.rs/tokio/0.2.21/tokio/sync/mpsc/struct.UnboundedReceiver.html#method.recv - pub async fn recv(&mut self) -> Option { + pub async fn recv(&mut self) -> Option<(Request, ResponseSender)> { self.0.recv().await } @@ -141,23 +190,12 @@ impl RequestReceiver { /// See [the `tokio` documentation][try_receive] for more information. /// /// [try_receive]: https://docs.rs/tokio/0.2.21/tokio/sync/mpsc/struct.UnboundedReceiver.html#method.try_recv - pub fn try_recv(&mut self) -> Result { + pub fn try_recv( + &mut self, + ) -> Result< + (Request, ResponseSender), + tokio::sync::mpsc::error::TryRecvError, + > { self.0.try_recv() } } - -#[cfg(test)] -mod tests { - use super::*; - - fn drop(_t: T) {} - - #[tokio::test] - async fn test_channel() { - let (mut recv, snd) = RequestReceiver::<()>::new(); - snd.send(()).unwrap(); - recv.recv().await.unwrap(); - drop(snd); - assert!(recv.recv().await.is_none()); - } -} diff --git a/rust/src/state_machine/tests/builder.rs b/rust/src/state_machine/tests/builder.rs index 3b25a18e2..d7f9918fd 100644 --- a/rust/src/state_machine/tests/builder.rs +++ b/rust/src/state_machine/tests/builder.rs @@ -5,7 +5,7 @@ use crate::{ coordinator::{CoordinatorState, RoundSeed}, events::EventSubscriber, phases::{self, Handler, Phase, PhaseState}, - requests::{Request, RequestReceiver, RequestSender}, + requests::{RequestReceiver, RequestSender}, tests::utils, StateMachine, }, @@ -36,29 +36,23 @@ impl StateMachineBuilder { impl

StateMachineBuilder

where - PhaseState: Handler + Phase, - StateMachine: From>, + PhaseState

: Handler + Phase, + StateMachine: From>, { - pub fn build( - self, - ) -> ( - StateMachine, - RequestSender, - EventSubscriber, - ) { + pub fn build(self) -> (StateMachine, RequestSender, EventSubscriber) { let Self { mut coordinator_state, event_subscriber, phase_state, } = self; - let (request_rx, request_tx) = RequestReceiver::::new(); + let (request_rx, request_tx) = RequestReceiver::new(); // Make sure the events that the listeners have are up to date let events = &mut coordinator_state.events; events.broadcast_keys(coordinator_state.keys.clone()); events.broadcast_params(coordinator_state.round_params.clone()); - events.broadcast_phase( as Phase>::NAME); + events.broadcast_phase( as Phase>::NAME); // Also re-emit the other events in case the round ID changed let scalar = event_subscriber.scalar_listener().get_latest().event; events.broadcast_scalar(scalar); diff --git a/rust/src/state_machine/tests/impls.rs b/rust/src/state_machine/tests/impls.rs index 3f740cade..0db01ee0e 100644 --- a/rust/src/state_machine/tests/impls.rs +++ b/rust/src/state_machine/tests/impls.rs @@ -4,26 +4,23 @@ use crate::{ state_machine::{ events::{DictionaryUpdate, MaskLengthUpdate}, phases::{self, PhaseState}, - requests::{ - Request, - RequestSender, - Sum2Request, - Sum2Response, - SumRequest, - SumResponse, - UpdateRequest, - UpdateResponse, - }, + requests::RequestSender, StateMachine, + StateMachineResult, }, + utils::Request, LocalSeedDict, SumParticipantEphemeralPublicKey, SumParticipantPublicKey, }; -use tokio::sync::oneshot; +impl RequestSender { + pub async fn msg(&self, msg: &MessageOwned) -> StateMachineResult { + self.request(Request::new(msg.clone())).await + } +} -impl StateMachine { +impl StateMachine { pub fn is_update(&self) -> bool { match self { StateMachine::Update(_) => true, @@ -31,7 +28,7 @@ impl StateMachine { } } - pub fn into_update_phase_state(self) -> PhaseState { + pub fn into_update_phase_state(self) -> PhaseState { match self { StateMachine::Update(state) => state, _ => panic!("not in update state"), @@ -45,7 +42,7 @@ impl StateMachine { } } - pub fn into_sum_phase_state(self) -> PhaseState { + pub fn into_sum_phase_state(self) -> PhaseState { match self { StateMachine::Sum(state) => state, _ => panic!("not in sum state"), @@ -59,7 +56,7 @@ impl StateMachine { } } - pub fn into_sum2_phase_state(self) -> PhaseState { + pub fn into_sum2_phase_state(self) -> PhaseState { match self { StateMachine::Sum2(state) => state, _ => panic!("not in sum2 state"), @@ -73,7 +70,7 @@ impl StateMachine { } } - pub fn into_idle_phase_state(self) -> PhaseState { + pub fn into_idle_phase_state(self) -> PhaseState { match self { StateMachine::Idle(state) => state, _ => panic!("not in idle state"), @@ -87,7 +84,7 @@ impl StateMachine { } } - pub fn into_unmask_phase_state(self) -> PhaseState { + pub fn into_unmask_phase_state(self) -> PhaseState { match self { StateMachine::Unmask(state) => state, _ => panic!("not in unmask state"), @@ -101,7 +98,7 @@ impl StateMachine { } } - pub fn into_error_phase_state(self) -> PhaseState { + pub fn into_error_phase_state(self) -> PhaseState { match self { StateMachine::Error(state) => state, _ => panic!("not in error state"), @@ -115,7 +112,7 @@ impl StateMachine { } } - pub fn into_shutdown_phase_state(self) -> PhaseState { + pub fn into_shutdown_phase_state(self) -> PhaseState { match self { StateMachine::Shutdown(state) => state, _ => panic!("not in shutdown state"), @@ -123,57 +120,6 @@ impl StateMachine { } } -impl RequestSender { - pub async fn sum(&mut self, msg: &MessageOwned) -> SumResponse { - let (resp_tx, resp_rx) = oneshot::channel::(); - let req = Request::Sum((msg.into(), resp_tx)); - self.send(req).unwrap(); - resp_rx.await.unwrap() - } - - pub async fn update(&mut self, msg: &MessageOwned) -> UpdateResponse { - let (resp_tx, resp_rx) = oneshot::channel::(); - let req = Request::Update((msg.into(), resp_tx)); - self.send(req).unwrap(); - resp_rx.await.unwrap() - } - - pub async fn sum2(&mut self, msg: &MessageOwned) -> Sum2Response { - let (resp_tx, resp_rx) = oneshot::channel::(); - let req = Request::Sum2((msg.into(), resp_tx)); - self.send(req).unwrap(); - resp_rx.await.unwrap() - } -} - -impl<'a> From<&'a MessageOwned> for SumRequest { - fn from(msg: &'a MessageOwned) -> SumRequest { - SumRequest { - participant_pk: msg.participant_pk(), - ephm_pk: msg.ephm_pk(), - } - } -} - -impl<'a> From<&'a MessageOwned> for UpdateRequest { - fn from(msg: &'a MessageOwned) -> UpdateRequest { - UpdateRequest { - participant_pk: msg.participant_pk(), - local_seed_dict: msg.local_seed_dict(), - masked_model: msg.masked_model(), - } - } -} - -impl<'a> From<&'a MessageOwned> for Sum2Request { - fn from(msg: &'a MessageOwned) -> Sum2Request { - Sum2Request { - participant_pk: msg.participant_pk(), - mask: msg.mask(), - } - } -} - impl MessageOwned { /// Extract the participant public key from the message. pub fn participant_pk(&self) -> SumParticipantPublicKey { diff --git a/rust/src/state_machine/tests/mod.rs b/rust/src/state_machine/tests/mod.rs index 15b9b41a3..7dc725fde 100644 --- a/rust/src/state_machine/tests/mod.rs +++ b/rust/src/state_machine/tests/mod.rs @@ -28,7 +28,7 @@ async fn full_round() { let coord_pk = &coord_keys.public; let model_size = 4; - let (state_machine, mut requests, events) = StateMachineBuilder::new() + let (state_machine, requests, events) = StateMachineBuilder::new() .with_round_id(42) .with_seed(seed.clone()) .with_sum_ratio(sum_ratio) @@ -50,8 +50,8 @@ async fn full_round() { let mut summer_2 = generate_summer(&seed, sum_ratio, update_ratio); let msg_1 = summer_1.compose_sum_message(coord_pk); let msg_2 = summer_2.compose_sum_message(coord_pk); - let req_1 = async { requests.clone().sum(&msg_1).await.unwrap() }; - let req_2 = async { requests.clone().sum(&msg_2).await.unwrap() }; + let req_1 = async { requests.msg(&msg_1).await.unwrap() }; + let req_2 = async { requests.msg(&msg_2).await.unwrap() }; let transition = async { state_machine.next().await.unwrap() }; let ((), (), state_machine) = tokio::join!(req_1, req_2, transition); assert!(state_machine.is_update()); @@ -64,7 +64,7 @@ async fn full_round() { for _ in 0..3 { let updater = generate_updater(&seed, sum_ratio, update_ratio); let msg = updater.compose_update_message(*coord_pk, &sum_dict, scalar, model.clone()); - requests.update(&msg).await.unwrap(); + requests.msg(&msg).await.unwrap(); } let state_machine = transition_task.await.unwrap(); assert!(state_machine.is_sum2()); @@ -78,8 +78,8 @@ async fn full_round() { let msg_2 = summer_2 .compose_sum2_message(*coord_pk, seed_dict.get(&summer_2.pk).unwrap(), mask_length) .unwrap(); - let req_1 = async { requests.clone().sum2(&msg_1).await.unwrap() }; - let req_2 = async { requests.clone().sum2(&msg_2).await.unwrap() }; + let req_1 = async { requests.msg(&msg_1).await.unwrap() }; + let req_2 = async { requests.msg(&msg_2).await.unwrap() }; let transition = async { state_machine.next().await.unwrap() }; let ((), (), state_machine) = tokio::join!(req_1, req_2, transition); assert!(state_machine.is_unmask()); diff --git a/rust/src/utils/mod.rs b/rust/src/utils/mod.rs index 11ec91669..0cab0442e 100644 --- a/rust/src/utils/mod.rs +++ b/rust/src/utils/mod.rs @@ -1 +1,2 @@ -pub mod trace; +pub mod request; +pub use self::request::{Request, Traceable}; diff --git a/rust/src/utils/request.rs b/rust/src/utils/request.rs new file mode 100644 index 000000000..0bf0a0345 --- /dev/null +++ b/rust/src/utils/request.rs @@ -0,0 +1,132 @@ +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use tracing::Span; +use uuid::Uuid; + +/// A type that can be associated to a span, making it traceable. +pub trait Traceable { + fn make_span(&self) -> Span; +} + +// NOTE: currently `id` and `timestamp` are immutable. `span` is +// mutable, but when it is changed other copies of the RequestMetadata +// are not affected. In the future, we can have shared mutable fields +// if we want to, by adding an Arc> field. +#[derive(Debug, Clone, PartialEq)] +pub struct RequestMetadata { + /// A random UUID associated to the request + id: Uuid, + /// Time the request was created + timestamp: SystemTime, + /// Current span associated to this request + span: Span, +} + +impl RequestMetadata { + fn new() -> Self { + let id = Uuid::new_v4(); + let timestamp = SystemTime::now(); + let span = error_span!("request", id = %id, timestamp = %timestamp.duration_since(UNIX_EPOCH).unwrap_or_else(|_| Duration::new(0, 0)).as_millis()); + Self { + id, + timestamp, + span, + } + } + + /// Time elapsed since this request was created, in milli seconds + fn elapsed(&self) -> u128 { + SystemTime::now() + .duration_since(self.timestamp) + .unwrap_or_else(|_| Duration::new(0, 0)) + .as_millis() + } + + /// Return the span associated with the metadata + fn span(&self) -> Span { + self.span.clone() + } +} + +#[derive(Debug, Clone, PartialEq)] +/// A request that can be handled by a service +pub struct Request { + /// Content of the request + inner: T, + /// Metadata associated to this request + metadata: RequestMetadata, +} + +impl Request +where + T: Traceable, +{ + /// Create a new request + pub fn new(t: T) -> Self { + Self { + inner: t, + metadata: RequestMetadata::new(), + } + } + + /// Create a [`Request`] with the given metadata and inner + /// request value. + pub fn from_parts(metadata: RequestMetadata, inner: T) -> Self { + Self { metadata, inner } + } + + /// Return the metadata attached to this [`Request`] + pub fn metadata(&self) -> RequestMetadata { + self.metadata.clone() + } + + /// Turn this `Request` into a `Request`. A new span is + /// created with `::make_span` and attached to the + /// request. + pub fn map(self, f: F) -> Request + where + F: ::std::ops::FnOnce(T) -> U, + U: Traceable, + { + let Request { + mut metadata, + inner, + } = self; + let mapped = f(inner); + + // self.span() is the parent of the span associated to the + // inner type + let new_span = metadata.span().in_scope(|| mapped.make_span()); + metadata.span = new_span; + + Request { + metadata, + inner: mapped, + } + } + + /// Span associated with this request + pub fn span(&self) -> Span { + self.metadata.span() + } + + /// Time elapsed since this request was created, in milli seconds + pub fn elapsed(&self) -> u128 { + self.metadata.elapsed() + } + + /// Get a reference to the request's inner value + pub fn inner(&self) -> &T { + &self.inner + } + + /// Get a mutable reference to the request's inner value + pub fn inner_mut(&mut self) -> &mut T { + &mut self.inner + } + + /// Consume this request and return its inner value + pub fn into_inner(self) -> T { + self.inner + } +} diff --git a/rust/src/utils/trace.rs b/rust/src/utils/trace.rs deleted file mode 100644 index c53e1acd8..000000000 --- a/rust/src/utils/trace.rs +++ /dev/null @@ -1,46 +0,0 @@ -use tracing::Span; -/// A type that can be associated to a span, making it traceable. -pub trait Traceable { - type Target: Sized; - fn span(&self) -> &Span; - fn span_mut(&mut self) -> &mut Span; - fn into_inner(self) -> Self::Target; -} - -/// A wrapper that associates a tracing span to `T` -#[derive(Debug, Hash, Clone)] -pub struct Traced { - inner: T, - span: Span, -} - -impl Traced { - pub fn new(req: T, span: Span) -> Self { - Self { inner: req, span } - } - - pub fn map(self, f: F) -> Traced - where - F: ::std::ops::FnOnce(T) -> U, - { - let Traced { span, inner } = self; - Traced { - span, - inner: f(inner), - } - } -} - -impl Traceable for Traced { - type Target = T; - - fn span(&self) -> &Span { - &self.span - } - fn span_mut(&mut self) -> &mut Span { - &mut self.span - } - fn into_inner(self) -> T { - self.inner - } -}