From 7ca1612bae7c60652a5ce7e31016b66c17d744d9 Mon Sep 17 00:00:00 2001 From: skedia Date: Fri, 12 Apr 2024 23:29:11 -0700 Subject: [PATCH] Review comments --- rust/worker/src/execution/dispatcher.rs | 22 +++++++------ rust/worker/src/execution/operator.rs | 10 +----- .../src/execution/orchestration/compact.rs | 8 ++--- .../src/execution/orchestration/hnsw.rs | 26 +++++++++------- rust/worker/src/execution/worker_thread.rs | 11 ++----- .../src/memberlist/memberlist_provider.rs | 2 +- rust/worker/src/system/executor.rs | 19 +++++++++--- rust/worker/src/system/scheduler.rs | 4 +-- rust/worker/src/system/sender.rs | 31 +++++++++++++++---- rust/worker/src/system/system.rs | 2 +- rust/worker/src/system/types.rs | 8 ++--- 11 files changed, 84 insertions(+), 59 deletions(-) diff --git a/rust/worker/src/execution/dispatcher.rs b/rust/worker/src/execution/dispatcher.rs index 952bb15c96f..a2bb177d9b9 100644 --- a/rust/worker/src/execution/dispatcher.rs +++ b/rust/worker/src/execution/dispatcher.rs @@ -7,6 +7,7 @@ use crate::{ }; use async_trait::async_trait; use std::fmt::Debug; +use tracing::{debug_span, instrument, Instrument, Span}; /// The dispatcher is responsible for distributing tasks to worker threads. /// It is a component that receives tasks and distributes them to worker threads. @@ -97,7 +98,11 @@ impl Dispatcher { // If a worker is waiting for a task, send it to the worker in FIFO order // Otherwise, add it to the task queue match self.waiters.pop() { - Some(channel) => match channel.reply_to.send(task).await { + Some(channel) => match channel + .reply_to + .send(task, Some(Span::current().clone())) + .await + { Ok(_) => {} Err(e) => { println!("Error sending task to worker: {:?}", e); @@ -116,7 +121,11 @@ impl Dispatcher { /// when one is available async fn handle_work_request(&mut self, request: TaskRequestMessage) { match self.task_queue.pop() { - Some(task) => match request.reply_to.send(task).await { + Some(task) => match request + .reply_to + .send(task, Some(Span::current().clone())) + .await + { Ok(_) => {} Err(e) => { println!("Error sending task to worker: {:?}", e); @@ -264,13 +273,8 @@ mod tests { #[async_trait] impl Handler<()> for MockDispatchUser { async fn handle(&mut self, _message: (), ctx: &ComponentContext) { - let task = wrap( - Box::new(MockOperator {}), - 42.0, - ctx.sender.as_receiver(), - None, - ); - let res = self.dispatcher.send(task).await; + let task = wrap(Box::new(MockOperator {}), 42.0, ctx.sender.as_receiver()); + let res = self.dispatcher.send(task, None).await; } } diff --git a/rust/worker/src/execution/operator.rs b/rust/worker/src/execution/operator.rs index 6747e6205e4..cce31192b73 100644 --- a/rust/worker/src/execution/operator.rs +++ b/rust/worker/src/execution/operator.rs @@ -27,7 +27,6 @@ where operator: Box>, input: Input, reply_channel: Box>>, - tracing_context: Option, } /// A message type used by the dispatcher to send tasks to worker threads. @@ -38,7 +37,6 @@ pub(crate) type TaskMessage = Box; #[async_trait] pub(crate) trait TaskWrapper: Send + Debug { async fn run(&self); - fn getTracingContext(&self) -> Option; } /// Implement the TaskWrapper trait for every Task. This allows us to @@ -53,13 +51,9 @@ where { async fn run(&self) { let output = self.operator.run(&self.input).await; - let res = self.reply_channel.send(output).await; + let res = self.reply_channel.send(output, None).await; // TODO: if this errors, it means the caller was dropped } - - fn getTracingContext(&self) -> Option { - self.tracing_context.clone() - } } /// Wrap an operator and its input into a task message. @@ -67,7 +61,6 @@ pub(super) fn wrap( operator: Box>, input: Input, reply_channel: Box>>, - tracing_context: Option, ) -> TaskMessage where Error: Debug + 'static, @@ -78,6 +71,5 @@ where operator, input, reply_channel, - tracing_context, }) } diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs index 473bc98ed0e..71138c9b4d8 100644 --- a/rust/worker/src/execution/orchestration/compact.rs +++ b/rust/worker/src/execution/orchestration/compact.rs @@ -105,8 +105,8 @@ impl CompactOrchestrator { } }; let input = PullLogsInput::new(collection_id, 0, 100, None, Some(end_timestamp)); - let task = wrap(operator, input, self_address, None); - match self.dispatcher.send(task).await { + let task = wrap(operator, input, self_address); + match self.dispatcher.send(task, None).await { Ok(_) => (), Err(e) => { // TODO: log an error and reply to caller @@ -124,8 +124,8 @@ impl CompactOrchestrator { let max_partition_size = 100; let operator = PartitionOperator::new(); let input = PartitionInput::new(records, max_partition_size); - let task = wrap(operator, input, self_address, None); - match self.dispatcher.send(task).await { + let task = wrap(operator, input, self_address); + match self.dispatcher.send(task, None).await { Ok(_) => (), Err(e) => { // TODO: log an error and reply to caller diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs index 31293d75910..b24f934d56f 100644 --- a/rust/worker/src/execution/orchestration/hnsw.rs +++ b/rust/worker/src/execution/orchestration/hnsw.rs @@ -114,9 +114,13 @@ impl HnswQueryOrchestrator { async fn pull_logs(&mut self, self_address: Box>) { self.state = ExecutionState::PullLogs; let operator = PullLogsOperator::new(self.log.clone()); - let child_span = debug_span!(parent: Span::current(), "get collection id for segment id"); + let child_span: tracing::Span = + debug_span!(parent: Span::current(), "get collection id for segment id"); let get_collection_id_future = self.get_collection_id_for_segment_id(self.segment_id); - let collection_id = match get_collection_id_future.instrument(child_span).await { + let collection_id = match get_collection_id_future + .instrument(child_span.clone()) + .await + { Some(collection_id) => collection_id, None => { // Log an error and reply + return @@ -133,10 +137,10 @@ impl HnswQueryOrchestrator { } }; let input = PullLogsInput::new(collection_id, 0, 100, None, Some(end_timestamp)); + let task = wrap(operator, input, self_address); // Wrap the task with current span as the parent. The worker then executes it // inside a child span with this parent. - let task = wrap(operator, input, self_address, Span::current().id().clone()); - match self.dispatcher.send(task).await { + match self.dispatcher.send(task, Some(child_span.clone())).await { Ok(_) => (), Err(e) => { // TODO: log an error and reply to caller @@ -193,13 +197,13 @@ impl Handler for HnswQueryOrchestrator { distance_metric: DistanceFunction::Euclidean, }; let operator = Box::new(BruteForceKnnOperator {}); - let task = wrap( - operator, - bf_input, - ctx.sender.as_receiver(), - Span::current().id().clone(), - ); - match self.dispatcher.send(task).await { + let task = wrap(operator, bf_input, ctx.sender.as_receiver()); + println!("Current span {:?}", Span::current()); + match self + .dispatcher + .send(task, Some(Span::current().clone())) + .await + { Ok(_) => (), Err(e) => { // TODO: log an error and reply to caller diff --git a/rust/worker/src/execution/worker_thread.rs b/rust/worker/src/execution/worker_thread.rs index 8ff1fd8c919..b866d392307 100644 --- a/rust/worker/src/execution/worker_thread.rs +++ b/rust/worker/src/execution/worker_thread.rs @@ -2,7 +2,6 @@ use super::{dispatcher::TaskRequestMessage, operator::TaskMessage}; use crate::system::{Component, ComponentContext, ComponentRuntime, Handler, Receiver}; use async_trait::async_trait; use std::fmt::{Debug, Formatter, Result}; -use tracing::{debug_span, instrument, Instrument}; /// A worker thread is responsible for executing tasks /// It sends requests to the dispatcher for new tasks. @@ -43,7 +42,7 @@ impl Component for WorkerThread { async fn on_start(&mut self, ctx: &ComponentContext) -> () { let req = TaskRequestMessage::new(ctx.sender.as_receiver()); - let res = self.dispatcher.send(req).await; + let res = self.dispatcher.send(req, None).await; // TODO: what to do with resp? } } @@ -51,13 +50,9 @@ impl Component for WorkerThread { #[async_trait] impl Handler for WorkerThread { async fn handle(&mut self, task: TaskMessage, ctx: &ComponentContext) { - // Execute the task with the caller span - let parent_id = (*task).getTracingContext(); - let child_span = debug_span!(parent: parent_id, "worker task execution"); - let task_future = task.run(); - task_future.instrument(child_span).await; + task.run().await; let req: TaskRequestMessage = TaskRequestMessage::new(ctx.sender.as_receiver()); - let res = self.dispatcher.send(req).await; + let res = self.dispatcher.send(req, None).await; // TODO: task run should be able to error and we should send it as part of the result } } diff --git a/rust/worker/src/memberlist/memberlist_provider.rs b/rust/worker/src/memberlist/memberlist_provider.rs index 9b529edb8bb..47d93087969 100644 --- a/rust/worker/src/memberlist/memberlist_provider.rs +++ b/rust/worker/src/memberlist/memberlist_provider.rs @@ -175,7 +175,7 @@ impl CustomResourceMemberlistProvider { }; for subscriber in self.subscribers.iter() { - let _ = subscriber.send(curr_memberlist.clone()).await; + let _ = subscriber.send(curr_memberlist.clone(), None).await; } } } diff --git a/rust/worker/src/system/executor.rs b/rust/worker/src/system/executor.rs index 4877273b70e..3bdaf2e2392 100644 --- a/rust/worker/src/system/executor.rs +++ b/rust/worker/src/system/executor.rs @@ -7,6 +7,7 @@ use super::{ use crate::system::ComponentContext; use std::sync::Arc; use tokio::select; +use tracing::{debug_span, instrument, span, Id, Instrument, Span}; struct Inner where @@ -69,14 +70,24 @@ where message = channel.recv() => { match message { Some(mut message) => { - message.handle(&mut self.handler, - &ComponentContext{ + let parent_span: tracing::Span; + match message.get_tracing_context() { + Some(spn) => { + parent_span = spn; + }, + None => { + parent_span = Span::current().clone(); + } + } + let child_span = debug_span!(parent: parent_span, "task handler"); + let component_context = ComponentContext { system: self.inner.system.clone(), sender: self.inner.sender.clone(), cancellation_token: self.inner.cancellation_token.clone(), scheduler: self.inner.scheduler.clone(), - } - ).await; + }; + let task_future = message.handle(&mut self.handler, &component_context); + task_future.instrument(child_span).await; } None => { // TODO: Log error diff --git a/rust/worker/src/system/scheduler.rs b/rust/worker/src/system/scheduler.rs index 34cb3b1872a..bcd7dbe308f 100644 --- a/rust/worker/src/system/scheduler.rs +++ b/rust/worker/src/system/scheduler.rs @@ -42,7 +42,7 @@ impl Scheduler { return; } _ = tokio::time::sleep(duration) => { - match sender.send(message).await { + match sender.send(message, None).await { Ok(_) => { return; }, @@ -83,7 +83,7 @@ impl Scheduler { return; } _ = tokio::time::sleep(duration) => { - match sender.send(message.clone()).await { + match sender.send(message.clone(), None).await { Ok(_) => { }, Err(e) => { diff --git a/rust/worker/src/system/sender.rs b/rust/worker/src/system/sender.rs index d9fb0785418..919f8757075 100644 --- a/rust/worker/src/system/sender.rs +++ b/rust/worker/src/system/sender.rs @@ -3,6 +3,7 @@ use std::fmt::Debug; use super::{Component, ComponentContext, Handler}; use async_trait::async_trait; use thiserror::Error; +use tracing::Span; // Message Wrapper #[derive(Debug)] @@ -11,12 +12,17 @@ where C: Component, { wrapper: Box>, + tracing_context: Option, } impl Wrapper { pub(super) async fn handle(&mut self, component: &mut C, ctx: &ComponentContext) -> () { self.wrapper.handle(component, ctx).await; } + + pub(super) fn get_tracing_context(&self) -> Option { + return self.tracing_context.clone(); + } } #[async_trait] @@ -40,13 +46,14 @@ where } } -pub(crate) fn wrap(message: M) -> Wrapper +pub(crate) fn wrap(message: M, tracing_context: Option) -> Wrapper where C: Component + Handler, M: Debug + Send + 'static, { Wrapper { wrapper: Box::new(Some(message)), + tracing_context, } } @@ -66,12 +73,16 @@ where Sender { sender } } - pub(crate) async fn send(&self, message: M) -> Result<(), ChannelError> + pub(crate) async fn send( + &self, + message: M, + tracing_context: Option, + ) -> Result<(), ChannelError> where C: Component + Handler, M: Debug + Send + 'static, { - let res = self.sender.send(wrap(message)).await; + let res = self.sender.send(wrap(message, tracing_context)).await; match res { Ok(_) => Ok(()), Err(_) => Err(ChannelError::SendError), @@ -102,7 +113,11 @@ where #[async_trait] pub(crate) trait Receiver: Send + Sync + Debug + ReceiverClone { - async fn send(&self, message: M) -> Result<(), ChannelError>; + async fn send( + &self, + message: M, + tracing_context: Option, + ) -> Result<(), ChannelError>; } trait ReceiverClone { @@ -159,8 +174,12 @@ where C: Component + Handler, M: Send + Debug + 'static, { - async fn send(&self, message: M) -> Result<(), ChannelError> { - let res = self.sender.send(wrap(message)).await; + async fn send( + &self, + message: M, + tracing_context: Option, + ) -> Result<(), ChannelError> { + let res = self.sender.send(wrap(message, tracing_context)).await; match res { Ok(_) => Ok(()), Err(_) => Err(ChannelError::SendError), diff --git a/rust/worker/src/system/system.rs b/rust/worker/src/system/system.rs index 0179b4459cf..ff6df93a667 100644 --- a/rust/worker/src/system/system.rs +++ b/rust/worker/src/system/system.rs @@ -104,7 +104,7 @@ where message = stream.next() => { match message { Some(message) => { - let res = ctx.sender.send(message).await; + let res = ctx.sender.send(message, None).await; match res { Ok(_) => {} Err(e) => { diff --git a/rust/worker/src/system/types.rs b/rust/worker/src/system/types.rs index e1e2228b146..accd89dfd8c 100644 --- a/rust/worker/src/system/types.rs +++ b/rust/worker/src/system/types.rs @@ -185,9 +185,9 @@ mod tests { let counter = Arc::new(AtomicUsize::new(0)); let component = TestComponent::new(10, counter.clone()); let mut handle = system.start_component(component); - handle.sender.send(1).await.unwrap(); - handle.sender.send(2).await.unwrap(); - handle.sender.send(3).await.unwrap(); + handle.sender.send(1, None).await.unwrap(); + handle.sender.send(2, None).await.unwrap(); + handle.sender.send(3, None).await.unwrap(); // yield to allow the component to process the messages tokio::task::yield_now().await; // With the streaming data and the messages we should have 12 @@ -197,7 +197,7 @@ mod tests { tokio::task::yield_now().await; // Expect the component to be stopped assert_eq!(*handle.state(), ComponentState::Stopped); - let res = handle.sender.send(4).await; + let res = handle.sender.send(4, None).await; // Expect an error because the component is stopped assert!(res.is_err()); }