Skip to content

Commit

Permalink
Review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
skedia committed Apr 13, 2024
1 parent dbfd849 commit 7ca1612
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 59 deletions.
22 changes: 13 additions & 9 deletions rust/worker/src/execution/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
};
use async_trait::async_trait;
use std::fmt::Debug;
use tracing::{debug_span, instrument, Instrument, Span};

/// The dispatcher is responsible for distributing tasks to worker threads.
/// It is a component that receives tasks and distributes them to worker threads.
Expand Down Expand Up @@ -97,7 +98,11 @@ impl Dispatcher {
// If a worker is waiting for a task, send it to the worker in FIFO order
// Otherwise, add it to the task queue
match self.waiters.pop() {
Some(channel) => match channel.reply_to.send(task).await {
Some(channel) => match channel
.reply_to
.send(task, Some(Span::current().clone()))
.await
{
Ok(_) => {}
Err(e) => {
println!("Error sending task to worker: {:?}", e);
Expand All @@ -116,7 +121,11 @@ impl Dispatcher {
/// when one is available
async fn handle_work_request(&mut self, request: TaskRequestMessage) {
match self.task_queue.pop() {
Some(task) => match request.reply_to.send(task).await {
Some(task) => match request
.reply_to
.send(task, Some(Span::current().clone()))
.await
{
Ok(_) => {}
Err(e) => {
println!("Error sending task to worker: {:?}", e);
Expand Down Expand Up @@ -264,13 +273,8 @@ mod tests {
#[async_trait]
impl Handler<()> for MockDispatchUser {
async fn handle(&mut self, _message: (), ctx: &ComponentContext<MockDispatchUser>) {
let task = wrap(
Box::new(MockOperator {}),
42.0,
ctx.sender.as_receiver(),
None,
);
let res = self.dispatcher.send(task).await;
let task = wrap(Box::new(MockOperator {}), 42.0, ctx.sender.as_receiver());
let res = self.dispatcher.send(task, None).await;
}
}

Expand Down
10 changes: 1 addition & 9 deletions rust/worker/src/execution/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ where
operator: Box<dyn Operator<Input, Output, Error = Error>>,
input: Input,
reply_channel: Box<dyn Receiver<Result<Output, Error>>>,
tracing_context: Option<tracing::Id>,
}

/// A message type used by the dispatcher to send tasks to worker threads.
Expand All @@ -38,7 +37,6 @@ pub(crate) type TaskMessage = Box<dyn TaskWrapper>;
#[async_trait]
pub(crate) trait TaskWrapper: Send + Debug {
async fn run(&self);
fn getTracingContext(&self) -> Option<tracing::Id>;
}

/// Implement the TaskWrapper trait for every Task. This allows us to
Expand All @@ -53,21 +51,16 @@ where
{
async fn run(&self) {
let output = self.operator.run(&self.input).await;
let res = self.reply_channel.send(output).await;
let res = self.reply_channel.send(output, None).await;
// TODO: if this errors, it means the caller was dropped
}

fn getTracingContext(&self) -> Option<tracing::Id> {
self.tracing_context.clone()
}
}

/// Wrap an operator and its input into a task message.
pub(super) fn wrap<Input, Output, Error>(
operator: Box<dyn Operator<Input, Output, Error = Error>>,
input: Input,
reply_channel: Box<dyn Receiver<Result<Output, Error>>>,
tracing_context: Option<tracing::Id>,
) -> TaskMessage
where
Error: Debug + 'static,
Expand All @@ -78,6 +71,5 @@ where
operator,
input,
reply_channel,
tracing_context,
})
}
8 changes: 4 additions & 4 deletions rust/worker/src/execution/orchestration/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ impl CompactOrchestrator {
}
};
let input = PullLogsInput::new(collection_id, 0, 100, None, Some(end_timestamp));
let task = wrap(operator, input, self_address, None);
match self.dispatcher.send(task).await {
let task = wrap(operator, input, self_address);
match self.dispatcher.send(task, None).await {
Ok(_) => (),
Err(e) => {
// TODO: log an error and reply to caller
Expand All @@ -124,8 +124,8 @@ impl CompactOrchestrator {
let max_partition_size = 100;
let operator = PartitionOperator::new();
let input = PartitionInput::new(records, max_partition_size);
let task = wrap(operator, input, self_address, None);
match self.dispatcher.send(task).await {
let task = wrap(operator, input, self_address);
match self.dispatcher.send(task, None).await {
Ok(_) => (),
Err(e) => {
// TODO: log an error and reply to caller
Expand Down
26 changes: 15 additions & 11 deletions rust/worker/src/execution/orchestration/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,13 @@ impl HnswQueryOrchestrator {
async fn pull_logs(&mut self, self_address: Box<dyn Receiver<PullLogsResult>>) {
self.state = ExecutionState::PullLogs;
let operator = PullLogsOperator::new(self.log.clone());
let child_span = debug_span!(parent: Span::current(), "get collection id for segment id");
let child_span: tracing::Span =
debug_span!(parent: Span::current(), "get collection id for segment id");
let get_collection_id_future = self.get_collection_id_for_segment_id(self.segment_id);
let collection_id = match get_collection_id_future.instrument(child_span).await {
let collection_id = match get_collection_id_future
.instrument(child_span.clone())
.await
{
Some(collection_id) => collection_id,
None => {
// Log an error and reply + return
Expand All @@ -133,10 +137,10 @@ impl HnswQueryOrchestrator {
}
};
let input = PullLogsInput::new(collection_id, 0, 100, None, Some(end_timestamp));
let task = wrap(operator, input, self_address);
// Wrap the task with current span as the parent. The worker then executes it
// inside a child span with this parent.
let task = wrap(operator, input, self_address, Span::current().id().clone());
match self.dispatcher.send(task).await {
match self.dispatcher.send(task, Some(child_span.clone())).await {
Ok(_) => (),
Err(e) => {
// TODO: log an error and reply to caller
Expand Down Expand Up @@ -193,13 +197,13 @@ impl Handler<PullLogsResult> for HnswQueryOrchestrator {
distance_metric: DistanceFunction::Euclidean,
};
let operator = Box::new(BruteForceKnnOperator {});
let task = wrap(
operator,
bf_input,
ctx.sender.as_receiver(),
Span::current().id().clone(),
);
match self.dispatcher.send(task).await {
let task = wrap(operator, bf_input, ctx.sender.as_receiver());
println!("Current span {:?}", Span::current());
match self
.dispatcher
.send(task, Some(Span::current().clone()))
.await
{
Ok(_) => (),
Err(e) => {
// TODO: log an error and reply to caller
Expand Down
11 changes: 3 additions & 8 deletions rust/worker/src/execution/worker_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use super::{dispatcher::TaskRequestMessage, operator::TaskMessage};
use crate::system::{Component, ComponentContext, ComponentRuntime, Handler, Receiver};
use async_trait::async_trait;
use std::fmt::{Debug, Formatter, Result};
use tracing::{debug_span, instrument, Instrument};

/// A worker thread is responsible for executing tasks
/// It sends requests to the dispatcher for new tasks.
Expand Down Expand Up @@ -43,21 +42,17 @@ impl Component for WorkerThread {

async fn on_start(&mut self, ctx: &ComponentContext<Self>) -> () {
let req = TaskRequestMessage::new(ctx.sender.as_receiver());
let res = self.dispatcher.send(req).await;
let res = self.dispatcher.send(req, None).await;
// TODO: what to do with resp?
}
}

#[async_trait]
impl Handler<TaskMessage> for WorkerThread {
async fn handle(&mut self, task: TaskMessage, ctx: &ComponentContext<WorkerThread>) {
// Execute the task with the caller span
let parent_id = (*task).getTracingContext();
let child_span = debug_span!(parent: parent_id, "worker task execution");
let task_future = task.run();
task_future.instrument(child_span).await;
task.run().await;
let req: TaskRequestMessage = TaskRequestMessage::new(ctx.sender.as_receiver());
let res = self.dispatcher.send(req).await;
let res = self.dispatcher.send(req, None).await;
// TODO: task run should be able to error and we should send it as part of the result
}
}
2 changes: 1 addition & 1 deletion rust/worker/src/memberlist/memberlist_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ impl CustomResourceMemberlistProvider {
};

for subscriber in self.subscribers.iter() {
let _ = subscriber.send(curr_memberlist.clone()).await;
let _ = subscriber.send(curr_memberlist.clone(), None).await;
}
}
}
Expand Down
19 changes: 15 additions & 4 deletions rust/worker/src/system/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use super::{
use crate::system::ComponentContext;
use std::sync::Arc;
use tokio::select;
use tracing::{debug_span, instrument, span, Id, Instrument, Span};

struct Inner<C>
where
Expand Down Expand Up @@ -69,14 +70,24 @@ where
message = channel.recv() => {
match message {
Some(mut message) => {
message.handle(&mut self.handler,
&ComponentContext{
let parent_span: tracing::Span;
match message.get_tracing_context() {
Some(spn) => {
parent_span = spn;
},
None => {
parent_span = Span::current().clone();
}
}
let child_span = debug_span!(parent: parent_span, "task handler");
let component_context = ComponentContext {
system: self.inner.system.clone(),
sender: self.inner.sender.clone(),
cancellation_token: self.inner.cancellation_token.clone(),
scheduler: self.inner.scheduler.clone(),
}
).await;
};
let task_future = message.handle(&mut self.handler, &component_context);
task_future.instrument(child_span).await;
}
None => {
// TODO: Log error
Expand Down
4 changes: 2 additions & 2 deletions rust/worker/src/system/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Scheduler {
return;
}
_ = tokio::time::sleep(duration) => {
match sender.send(message).await {
match sender.send(message, None).await {
Ok(_) => {
return;
},
Expand Down Expand Up @@ -83,7 +83,7 @@ impl Scheduler {
return;
}
_ = tokio::time::sleep(duration) => {
match sender.send(message.clone()).await {
match sender.send(message.clone(), None).await {
Ok(_) => {
},
Err(e) => {
Expand Down
31 changes: 25 additions & 6 deletions rust/worker/src/system/sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt::Debug;
use super::{Component, ComponentContext, Handler};
use async_trait::async_trait;
use thiserror::Error;
use tracing::Span;

// Message Wrapper
#[derive(Debug)]
Expand All @@ -11,12 +12,17 @@ where
C: Component,
{
wrapper: Box<dyn WrapperTrait<C>>,
tracing_context: Option<tracing::Span>,
}

impl<C: Component> Wrapper<C> {
    /// Delivers the wrapped message to `component`.
    ///
    /// Delegates to the boxed inner wrapper's `handle`, passing `component`
    /// and its context `ctx` through, and awaits completion.
    pub(super) async fn handle(&mut self, component: &mut C, ctx: &ComponentContext<C>) {
        self.wrapper.handle(component, ctx).await;
    }

    /// Returns a clone of the tracing span stored alongside this message,
    /// or `None` when no span was attached.
    pub(super) fn get_tracing_context(&self) -> Option<tracing::Span> {
        self.tracing_context.clone()
    }
}

#[async_trait]
Expand All @@ -40,13 +46,14 @@ where
}
}

pub(crate) fn wrap<C, M>(message: M) -> Wrapper<C>
pub(crate) fn wrap<C, M>(message: M, tracing_context: Option<tracing::Span>) -> Wrapper<C>
where
C: Component + Handler<M>,
M: Debug + Send + 'static,
{
Wrapper {
wrapper: Box::new(Some(message)),
tracing_context,
}
}

Expand All @@ -66,12 +73,16 @@ where
Sender { sender }
}

pub(crate) async fn send<M>(&self, message: M) -> Result<(), ChannelError>
pub(crate) async fn send<M>(
&self,
message: M,
tracing_context: Option<tracing::Span>,
) -> Result<(), ChannelError>
where
C: Component + Handler<M>,
M: Debug + Send + 'static,
{
let res = self.sender.send(wrap(message)).await;
let res = self.sender.send(wrap(message, tracing_context)).await;
match res {
Ok(_) => Ok(()),
Err(_) => Err(ChannelError::SendError),
Expand Down Expand Up @@ -102,7 +113,11 @@ where

#[async_trait]
pub(crate) trait Receiver<M>: Send + Sync + Debug + ReceiverClone<M> {
async fn send(&self, message: M) -> Result<(), ChannelError>;
async fn send(
&self,
message: M,
tracing_context: Option<tracing::Span>,
) -> Result<(), ChannelError>;
}

trait ReceiverClone<M> {
Expand Down Expand Up @@ -159,8 +174,12 @@ where
C: Component + Handler<M>,
M: Send + Debug + 'static,
{
async fn send(&self, message: M) -> Result<(), ChannelError> {
let res = self.sender.send(wrap(message)).await;
async fn send(
&self,
message: M,
tracing_context: Option<tracing::Span>,
) -> Result<(), ChannelError> {
let res = self.sender.send(wrap(message, tracing_context)).await;
match res {
Ok(_) => Ok(()),
Err(_) => Err(ChannelError::SendError),
Expand Down
2 changes: 1 addition & 1 deletion rust/worker/src/system/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ where
message = stream.next() => {
match message {
Some(message) => {
let res = ctx.sender.send(message).await;
let res = ctx.sender.send(message, None).await;
match res {
Ok(_) => {}
Err(e) => {
Expand Down
8 changes: 4 additions & 4 deletions rust/worker/src/system/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ mod tests {
let counter = Arc::new(AtomicUsize::new(0));
let component = TestComponent::new(10, counter.clone());
let mut handle = system.start_component(component);
handle.sender.send(1).await.unwrap();
handle.sender.send(2).await.unwrap();
handle.sender.send(3).await.unwrap();
handle.sender.send(1, None).await.unwrap();
handle.sender.send(2, None).await.unwrap();
handle.sender.send(3, None).await.unwrap();
// yield to allow the component to process the messages
tokio::task::yield_now().await;
// With the streaming data and the messages we should have 12
Expand All @@ -197,7 +197,7 @@ mod tests {
tokio::task::yield_now().await;
// Expect the component to be stopped
assert_eq!(*handle.state(), ComponentState::Stopped);
let res = handle.sender.send(4).await;
let res = handle.sender.send(4, None).await;
// Expect an error because the component is stopped
assert!(res.is_err());
}
Expand Down

0 comments on commit 7ca1612

Please sign in to comment.