Skip to content

Commit

Permalink
Clean up ShardManager/ShardQueuer/ShardRunner (#2653)
Browse files Browse the repository at this point in the history
  • Loading branch information
GnomedDev committed Mar 29, 2024
1 parent 84963d7 commit 6fa46ea
Show file tree
Hide file tree
Showing 19 changed files with 162 additions and 259 deletions.
6 changes: 4 additions & 2 deletions src/cache/event.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashSet;
use std::num::NonZeroU16;

use super::{Cache, CacheUpdate};
use crate::model::channel::{GuildChannel, Message};
Expand Down Expand Up @@ -491,12 +492,13 @@ impl CacheUpdate for ReadyEvent {
let mut guilds_to_remove = vec![];
let ready_guilds_hashset =
self.ready.guilds.iter().map(|status| status.id).collect::<HashSet<_>>();
let shard_data = self.ready.shard.unwrap_or_else(|| ShardInfo::new(ShardId(1), 1));
let shard_data =
self.ready.shard.unwrap_or_else(|| ShardInfo::new(ShardId(1), NonZeroU16::MIN));

for guild_entry in cache.guilds.iter() {
let guild = guild_entry.key();
// Only handle data for our shard.
if crate::utils::shard_id(*guild, shard_data.total) == shard_data.id.0
if crate::utils::shard_id(*guild, shard_data.total.get()) == shard_data.id.0
&& !ready_guilds_hashset.contains(guild)
{
guilds_to_remove.push(*guild);
Expand Down
7 changes: 4 additions & 3 deletions src/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

use std::collections::{HashSet, VecDeque};
use std::hash::Hash;
use std::num::NonZeroU16;
#[cfg(feature = "temp_cache")]
use std::sync::Arc;
#[cfg(feature = "temp_cache")]
Expand Down Expand Up @@ -125,7 +126,7 @@ pub type MessageRef<'a> = CacheRef<'a, ChannelId, Message, HashMap<MessageId, Me
#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
#[derive(Debug)]
pub(crate) struct CachedShardData {
pub total: u32,
pub total: NonZeroU16,
pub connected: HashSet<ShardId>,
pub has_sent_shards_ready: bool,
}
Expand Down Expand Up @@ -281,7 +282,7 @@ impl Cache {
message_queue: DashMap::default(),

shard_data: RwLock::new(CachedShardData {
total: 1,
total: NonZeroU16::MIN,
connected: HashSet::new(),
has_sent_shards_ready: false,
}),
Expand Down Expand Up @@ -534,7 +535,7 @@ impl Cache {

/// Returns the number of shards.
#[inline]
pub fn shard_count(&self) -> u32 {
pub fn shard_count(&self) -> NonZeroU16 {
self.shard_data.read().total
}

Expand Down
9 changes: 4 additions & 5 deletions src/client/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,12 @@ fn update_cache_with_event(
#[cfg(feature = "cache")]
{
let mut shards = cache.shard_data.write();
if shards.connected.len() as u32 == shards.total && !shards.has_sent_shards_ready {
if shards.connected.len() == shards.total.get() as usize
&& !shards.has_sent_shards_ready
{
shards.has_sent_shards_ready = true;
let total = shards.total;
drop(shards);

extra_event = Some(FullEvent::ShardsReady {
total_shards: total,
total_shards: shards.total,
});
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/client/event_handler.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::num::NonZeroU16;

use async_trait::async_trait;

use super::context::Context;
Expand Down Expand Up @@ -112,7 +114,7 @@ event_handler! {

/// Dispatched when every shard has received a Ready event
#[cfg(feature = "cache")]
ShardsReady { total_shards: u32 } => async fn shards_ready(&self, ctx: Context);
ShardsReady { total_shards: NonZeroU16 } => async fn shards_ready(&self, ctx: Context);

/// Dispatched when a channel is created.
///
Expand Down
61 changes: 22 additions & 39 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mod error;
mod event_handler;

use std::future::IntoFuture;
use std::num::NonZeroU16;
use std::ops::Range;
use std::sync::Arc;
#[cfg(feature = "framework")]
Expand All @@ -30,7 +31,7 @@ use std::sync::OnceLock;
use futures::channel::mpsc::UnboundedReceiver as Receiver;
use futures::future::BoxFuture;
use futures::StreamExt as _;
use tokio::sync::{Mutex, RwLock};
use tokio::sync::RwLock;
use tracing::{debug, error, info, instrument};
use typemap_rev::{TypeMap, TypeMapKey};

Expand All @@ -57,6 +58,7 @@ use crate::internal::prelude::*;
use crate::model::gateway::GatewayIntents;
use crate::model::id::ApplicationId;
use crate::model::user::OnlineStatus;
use crate::utils::check_shard_total;

/// A builder implementing [`IntoFuture`] building a [`Client`] to interact with Discord.
#[cfg(feature = "gateway")]
Expand Down Expand Up @@ -333,13 +335,13 @@ impl IntoFuture for ClientBuilder {
let cache = Arc::new(Cache::new_with_settings(self.cache_settings));

Box::pin(async move {
let ws_url = Arc::new(Mutex::new(match http.get_gateway().await {
Ok(response) => response.url,
let (ws_url, shard_total) = match http.get_bot_gateway().await {
Ok(response) => (Arc::from(response.url), response.shards),
Err(err) => {
tracing::warn!("HTTP request to get gateway URL failed: {}", err);
"wss://gateway.discord.gg".to_string()
tracing::warn!("HTTP request to get gateway URL failed: {err}");
(Arc::from("wss://gateway.discord.gg"), NonZeroU16::MIN)
},
}));
};

#[cfg(feature = "framework")]
let framework_cell = Arc::new(OnceLock::new());
Expand All @@ -349,12 +351,10 @@ impl IntoFuture for ClientBuilder {
raw_event_handlers,
#[cfg(feature = "framework")]
framework: Arc::clone(&framework_cell),
shard_index: 0,
shard_init: 0,
shard_total: 0,
#[cfg(feature = "voice")]
voice_manager: voice_manager.clone(),
ws_url: Arc::clone(&ws_url),
shard_total,
#[cfg(feature = "cache")]
cache: Arc::clone(&cache),
http: Arc::clone(&http),
Expand Down Expand Up @@ -586,11 +586,7 @@ pub struct Client {
#[cfg(feature = "voice")]
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
/// URL that the client's shards will use to connect to the gateway.
///
/// This is likely not important for production usage and is, at best, used for debugging.
///
/// This is wrapped in an `Arc<Mutex<T>>` so all shards will have an updated value available.
pub ws_url: Arc<Mutex<String>>,
pub ws_url: Arc<str>,
/// The cache for the client.
#[cfg(feature = "cache")]
pub cache: Arc<Cache>,
Expand Down Expand Up @@ -638,7 +634,7 @@ impl Client {
/// [gateway docs]: crate::gateway#sharding
#[instrument(skip(self))]
pub async fn start(&mut self) -> Result<()> {
self.start_connection(0, 0, 1).await
self.start_connection(0, 0, NonZeroU16::MIN).await
}

/// Establish the connection(s) and start listening for events.
Expand Down Expand Up @@ -681,8 +677,7 @@ impl Client {
pub async fn start_autosharded(&mut self) -> Result<()> {
let (end, total) = {
let res = self.http.get_bot_gateway().await?;

(res.shards - 1, res.shards)
(res.shards.get() - 1, res.shards)
};

self.start_connection(0, end, total).await
Expand Down Expand Up @@ -743,8 +738,8 @@ impl Client {
///
/// [gateway docs]: crate::gateway#sharding
#[instrument(skip(self))]
pub async fn start_shard(&mut self, shard: u32, shards: u32) -> Result<()> {
self.start_connection(shard, shard, shards).await
pub async fn start_shard(&mut self, shard: u16, shards: u16) -> Result<()> {
self.start_connection(shard, shard, check_shard_total(shards)).await
}

/// Establish sharded connections and start listening for events.
Expand Down Expand Up @@ -784,8 +779,8 @@ impl Client {
///
/// [Gateway docs]: crate::gateway#sharding
#[instrument(skip(self))]
pub async fn start_shards(&mut self, total_shards: u32) -> Result<()> {
self.start_connection(0, total_shards - 1, total_shards).await
pub async fn start_shards(&mut self, total_shards: u16) -> Result<()> {
self.start_connection(0, total_shards - 1, check_shard_total(total_shards)).await
}

/// Establish a range of sharded connections and start listening for events.
Expand Down Expand Up @@ -825,26 +820,16 @@ impl Client {
///
/// [Gateway docs]: crate::gateway#sharding
#[instrument(skip(self))]
pub async fn start_shard_range(&mut self, range: Range<u32>, total_shards: u32) -> Result<()> {
self.start_connection(range.start, range.end, total_shards).await
pub async fn start_shard_range(&mut self, range: Range<u16>, total_shards: u16) -> Result<()> {
self.start_connection(range.start, range.end, check_shard_total(total_shards)).await
}

/// Shard data layout is:
/// 0: first shard number to initialize
/// 1: shard number to initialize up to and including
/// 2: total number of shards the bot is sharding for
///
/// Not all shards need to be initialized in this process.
///
/// # Errors
///
/// Returns a [`ClientError::Shutdown`] when all shards have shutdown due to an error.
#[instrument(skip(self))]
async fn start_connection(
&mut self,
start_shard: u32,
end_shard: u32,
total_shards: u32,
start_shard: u16,
end_shard: u16,
total_shards: NonZeroU16,
) -> Result<()> {
#[cfg(feature = "voice")]
if let Some(voice_manager) = &self.voice_manager {
Expand All @@ -855,11 +840,9 @@ impl Client {

let init = end_shard - start_shard + 1;

self.shard_manager.set_shards(start_shard, init, total_shards).await;

debug!("Initializing shard info: {} - {}/{}", start_shard, init, total_shards);

if let Err(why) = self.shard_manager.initialize() {
if let Err(why) = self.shard_manager.initialize(start_shard, init, total_shards) {
error!("Failed to boot a shard: {:?}", why);
info!("Shutting down all shards");

Expand Down
8 changes: 5 additions & 3 deletions src/gateway/bridge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ mod shard_runner_message;
mod voice;

use std::fmt;
use std::num::NonZeroU16;
use std::time::Duration as StdDuration;

pub use self::event::ShardStageUpdateEvent;
Expand All @@ -68,9 +69,10 @@ use crate::model::id::ShardId;
/// A message to be sent to the [`ShardQueuer`].
#[derive(Clone, Debug)]
pub enum ShardQueuerMessage {
/// Message to start a shard, where the 0-index element is the ID of the Shard to start and the
/// 1-index element is the total shards in use.
Start(ShardId, ShardId),
/// Message to set the shard total.
SetShardTotal(NonZeroU16),
/// Message to start a shard.
Start(ShardId),
/// Message to shutdown the shard queuer.
Shutdown,
/// Message to dequeue/shutdown a shard.
Expand Down
Loading

0 comments on commit 6fa46ea

Please sign in to comment.