From 8b7a07a97d1f3ef53607ebddb67c1353b02d6238 Mon Sep 17 00:00:00 2001 From: Gnome! Date: Sat, 9 Dec 2023 20:21:26 +0000 Subject: [PATCH] Clean up `ShardManager`/`ShardQueuer`/`ShardRunner` (#2653) --- src/cache/event.rs | 6 +- src/cache/mod.rs | 7 ++- src/client/dispatch.rs | 9 ++- src/client/event_handler.rs | 3 +- src/client/mod.rs | 61 +++++++------------ src/gateway/bridge/mod.rs | 8 ++- src/gateway/bridge/shard_manager.rs | 88 ++++++++++----------------- src/gateway/bridge/shard_messenger.rs | 87 +++++--------------------- src/gateway/bridge/shard_queuer.rs | 39 ++++++------ src/gateway/bridge/shard_runner.rs | 4 +- src/gateway/bridge/voice.rs | 10 +-- src/gateway/shard.rs | 57 ++++++----------- src/model/gateway.rs | 8 +-- src/model/guild/guild_id.rs | 6 +- src/model/guild/mod.rs | 4 +- src/model/guild/partial_guild.rs | 4 +- src/model/id.rs | 2 +- src/model/invite.rs | 4 +- src/utils/mod.rs | 13 +++- 19 files changed, 161 insertions(+), 259 deletions(-) diff --git a/src/cache/event.rs b/src/cache/event.rs index 483160e2659..f8981e4fd49 100644 --- a/src/cache/event.rs +++ b/src/cache/event.rs @@ -1,4 +1,5 @@ use std::collections::HashSet; +use std::num::NonZeroU16; use super::{Cache, CacheUpdate}; use crate::model::channel::{GuildChannel, Message}; @@ -455,12 +456,13 @@ impl CacheUpdate for ReadyEvent { let mut guilds_to_remove = vec![]; let ready_guilds_hashset = self.ready.guilds.iter().map(|status| status.id).collect::>(); - let shard_data = self.ready.shard.unwrap_or_else(|| ShardInfo::new(ShardId(1), 1)); + let shard_data = + self.ready.shard.unwrap_or_else(|| ShardInfo::new(ShardId(1), NonZeroU16::MIN)); for guild_entry in cache.guilds.iter() { let guild = guild_entry.key(); // Only handle data for our shard. - if crate::utils::shard_id(*guild, shard_data.total) == shard_data.id.0 + if crate::utils::shard_id(*guild, shard_data.total.get()) == shard_data.id.0 && !ready_guilds_hashset.contains(guild) { guilds_to_remove.push(*guild); diff --git a/src/cache/mod.rs b/src/cache/mod.rs index a705e2b241c..c569666dd8f 100644 --- a/src/cache/mod.rs +++ b/src/cache/mod.rs @@ -25,6 +25,7 @@ use std::collections::{HashMap, HashSet, VecDeque}; use std::hash::Hash; +use std::num::NonZeroU16; #[cfg(feature = "temp_cache")] use std::sync::Arc; #[cfg(feature = "temp_cache")] @@ -125,7 +126,7 @@ pub type MessageRef<'a> = CacheRef<'a, ChannelId, Message, HashMap, pub has_sent_shards_ready: bool, } @@ -281,7 +282,7 @@ impl Cache { message_queue: DashMap::default(), shard_data: RwLock::new(CachedShardData { - total: 1, + total: NonZeroU16::MIN, connected: HashSet::new(), has_sent_shards_ready: false, }), @@ -539,7 +540,7 @@ impl Cache { /// Returns the number of shards. #[inline] - pub fn shard_count(&self) -> u32 { + pub fn shard_count(&self) -> NonZeroU16 { self.shard_data.read().total } diff --git a/src/client/dispatch.rs b/src/client/dispatch.rs index eb0802ee003..5851139fe6c 100644 --- a/src/client/dispatch.rs +++ b/src/client/dispatch.rs @@ -348,13 +348,12 @@ fn update_cache_with_event( #[cfg(feature = "cache")] { let mut shards = cache.shard_data.write(); - if shards.connected.len() as u32 == shards.total && !shards.has_sent_shards_ready { + if shards.connected.len() == shards.total.get() as usize + && !shards.has_sent_shards_ready + { shards.has_sent_shards_ready = true; - let total = shards.total; - drop(shards); - extra_event = Some(FullEvent::ShardsReady { - total_shards: total, + total_shards: shards.total, }); } } diff --git a/src/client/event_handler.rs b/src/client/event_handler.rs index f18807ccc76..3bed39a38a3 100644 --- a/src/client/event_handler.rs +++ b/src/client/event_handler.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::num::NonZeroU16; use async_trait::async_trait; @@ -117,7 +118,7 @@ event_handler! { /// Dispatched when every shard has received a Ready event #[cfg(feature = "cache")] - ShardsReady { total_shards: u32 } => async fn shards_ready(&self, ctx: Context); + ShardsReady { total_shards: NonZeroU16 } => async fn shards_ready(&self, ctx: Context); /// Dispatched when a channel is created. /// diff --git a/src/client/mod.rs b/src/client/mod.rs index 0074c50193e..0850cb57060 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -22,6 +22,7 @@ mod error; mod event_handler; use std::future::IntoFuture; +use std::num::NonZeroU16; use std::ops::Range; use std::sync::Arc; #[cfg(feature = "framework")] @@ -30,7 +31,7 @@ use std::sync::OnceLock; use futures::channel::mpsc::UnboundedReceiver as Receiver; use futures::future::BoxFuture; use futures::StreamExt as _; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::RwLock; use tracing::{debug, error, info, instrument}; use typemap_rev::{TypeMap, TypeMapKey}; @@ -57,6 +58,7 @@ use crate::internal::prelude::*; use crate::model::gateway::GatewayIntents; use crate::model::id::ApplicationId; use crate::model::user::OnlineStatus; +use crate::utils::check_shard_total; /// A builder implementing [`IntoFuture`] building a [`Client`] to interact with Discord. #[cfg(feature = "gateway")] @@ -333,13 +335,13 @@ impl IntoFuture for ClientBuilder { let cache = Arc::new(Cache::new_with_settings(self.cache_settings)); Box::pin(async move { - let ws_url = Arc::new(Mutex::new(match http.get_gateway().await { - Ok(response) => response.url, + let (ws_url, shard_total) = match http.get_bot_gateway().await { + Ok(response) => (Arc::from(response.url), response.shards), Err(err) => { - tracing::warn!("HTTP request to get gateway URL failed: {}", err); - "wss://gateway.discord.gg".to_string() + tracing::warn!("HTTP request to get gateway URL failed: {err}"); + (Arc::from("wss://gateway.discord.gg"), NonZeroU16::MIN) }, - })); + }; #[cfg(feature = "framework")] let framework_cell = Arc::new(OnceLock::new()); @@ -349,12 +351,10 @@ impl IntoFuture for ClientBuilder { raw_event_handlers, #[cfg(feature = "framework")] framework: Arc::clone(&framework_cell), - shard_index: 0, - shard_init: 0, - shard_total: 0, #[cfg(feature = "voice")] voice_manager: voice_manager.as_ref().map(Arc::clone), ws_url: Arc::clone(&ws_url), + shard_total, #[cfg(feature = "cache")] cache: Arc::clone(&cache), http: Arc::clone(&http), @@ -586,11 +586,7 @@ pub struct Client { #[cfg(feature = "voice")] pub voice_manager: Option>, /// URL that the client's shards will use to connect to the gateway. - /// - /// This is likely not important for production usage and is, at best, used for debugging. - /// - /// This is wrapped in an `Arc>` so all shards will have an updated value available. - pub ws_url: Arc>, + pub ws_url: Arc, /// The cache for the client. #[cfg(feature = "cache")] pub cache: Arc, @@ -638,7 +634,7 @@ impl Client { /// [gateway docs]: crate::gateway#sharding #[instrument(skip(self))] pub async fn start(&mut self) -> Result<()> { - self.start_connection(0, 0, 1).await + self.start_connection(0, 0, NonZeroU16::MIN).await } /// Establish the connection(s) and start listening for events. @@ -681,8 +677,7 @@ impl Client { pub async fn start_autosharded(&mut self) -> Result<()> { let (end, total) = { let res = self.http.get_bot_gateway().await?; - - (res.shards - 1, res.shards) + (res.shards.get() - 1, res.shards) }; self.start_connection(0, end, total).await @@ -743,8 +738,8 @@ impl Client { /// /// [gateway docs]: crate::gateway#sharding #[instrument(skip(self))] - pub async fn start_shard(&mut self, shard: u32, shards: u32) -> Result<()> { - self.start_connection(shard, shard, shards).await + pub async fn start_shard(&mut self, shard: u16, shards: u16) -> Result<()> { + self.start_connection(shard, shard, check_shard_total(shards)).await } /// Establish sharded connections and start listening for events. @@ -784,8 +779,8 @@ impl Client { /// /// [Gateway docs]: crate::gateway#sharding #[instrument(skip(self))] - pub async fn start_shards(&mut self, total_shards: u32) -> Result<()> { - self.start_connection(0, total_shards - 1, total_shards).await + pub async fn start_shards(&mut self, total_shards: u16) -> Result<()> { + self.start_connection(0, total_shards - 1, check_shard_total(total_shards)).await } /// Establish a range of sharded connections and start listening for events. @@ -825,26 +820,16 @@ impl Client { /// /// [Gateway docs]: crate::gateway#sharding #[instrument(skip(self))] - pub async fn start_shard_range(&mut self, range: Range, total_shards: u32) -> Result<()> { - self.start_connection(range.start, range.end, total_shards).await + pub async fn start_shard_range(&mut self, range: Range, total_shards: u16) -> Result<()> { + self.start_connection(range.start, range.end, check_shard_total(total_shards)).await } - /// Shard data layout is: - /// 0: first shard number to initialize - /// 1: shard number to initialize up to and including - /// 2: total number of shards the bot is sharding for - /// - /// Not all shards need to be initialized in this process. - /// - /// # Errors - /// - /// Returns a [`ClientError::Shutdown`] when all shards have shutdown due to an error. #[instrument(skip(self))] async fn start_connection( &mut self, - start_shard: u32, - end_shard: u32, - total_shards: u32, + start_shard: u16, + end_shard: u16, + total_shards: NonZeroU16, ) -> Result<()> { #[cfg(feature = "voice")] if let Some(voice_manager) = &self.voice_manager { @@ -855,11 +840,9 @@ impl Client { let init = end_shard - start_shard + 1; - self.shard_manager.set_shards(start_shard, init, total_shards).await; - debug!("Initializing shard info: {} - {}/{}", start_shard, init, total_shards); - if let Err(why) = self.shard_manager.initialize() { + if let Err(why) = self.shard_manager.initialize(start_shard, init, total_shards) { error!("Failed to boot a shard: {:?}", why); info!("Shutting down all shards"); diff --git a/src/gateway/bridge/mod.rs b/src/gateway/bridge/mod.rs index 45543d26711..4fc51c1cd87 100644 --- a/src/gateway/bridge/mod.rs +++ b/src/gateway/bridge/mod.rs @@ -50,6 +50,7 @@ mod shard_runner_message; mod voice; use std::fmt; +use std::num::NonZeroU16; use std::time::Duration as StdDuration; pub use self::event::ShardStageUpdateEvent; @@ -68,9 +69,10 @@ use crate::model::id::ShardId; /// A message to be sent to the [`ShardQueuer`]. #[derive(Clone, Debug)] pub enum ShardQueuerMessage { - /// Message to start a shard, where the 0-index element is the ID of the Shard to start and the - /// 1-index element is the total shards in use. - Start(ShardId, ShardId), + /// Message to set the shard total. + SetShardTotal(NonZeroU16), + /// Message to start a shard. + Start(ShardId), /// Message to shutdown the shard queuer. Shutdown, /// Message to dequeue/shutdown a shard. diff --git a/src/gateway/bridge/shard_manager.rs b/src/gateway/bridge/shard_manager.rs index d9ccb8081bd..9dfa67f06cf 100644 --- a/src/gateway/bridge/shard_manager.rs +++ b/src/gateway/bridge/shard_manager.rs @@ -1,5 +1,5 @@ use std::collections::{HashMap, VecDeque}; -use std::sync::atomic::{AtomicU32, Ordering}; +use std::num::NonZeroU16; use std::sync::Arc; #[cfg(feature = "framework")] use std::sync::OnceLock; @@ -65,7 +65,10 @@ use crate::model::gateway::GatewayIntents; /// impl RawEventHandler for Handler {} /// /// # let http: Arc = unimplemented!(); -/// let ws_url = Arc::new(Mutex::new(http.get_gateway().await?.url)); +/// let gateway_info = http.get_bot_gateway().await?; +/// +/// let shard_total = gateway_info.shards; +/// let ws_url = Arc::from(gateway_info.url); /// let data = Arc::new(RwLock::new(TypeMap::new())); /// let event_handler = Arc::new(Handler) as Arc; /// let framework = Arc::new(StandardFramework::new()) as Arc; @@ -75,15 +78,10 @@ use crate::model::gateway::GatewayIntents; /// event_handlers: vec![event_handler], /// raw_event_handlers: vec![], /// framework: Arc::new(OnceLock::from(framework)), -/// // the shard index to start initiating from -/// shard_index: 0, -/// // the number of shards to initiate (this initiates 0, 1, and 2) -/// shard_init: 3, -/// // the total number of shards in use -/// shard_total: 5, /// # #[cfg(feature = "voice")] /// # voice_manager: None, /// ws_url, +/// shard_total, /// # #[cfg(feature = "cache")] /// # cache: unimplemented!(), /// # http, @@ -103,13 +101,6 @@ pub struct ShardManager { /// **Note**: It is highly unrecommended to mutate this yourself unless you need to. Instead /// prefer to use methods on this struct that are provided where possible. pub runners: Arc>>, - /// The index of the first shard to initialize, 0-indexed. - // Atomics are used here to allow for mutation without requiring a mutable reference to self. - shard_index: AtomicU32, - /// The number of shards to initialize. - shard_init: AtomicU32, - /// The total shards in use, 1-indexed. - shard_total: AtomicU32, shard_queuer: Sender, // We can safely use a Mutex for this field, as it is only ever used in one single place // and only is ever used to receive a single message @@ -131,10 +122,7 @@ impl ShardManager { let manager = Arc::new(Self { return_value_tx: Mutex::new(return_value_tx), - shard_index: AtomicU32::new(opt.shard_index), - shard_init: AtomicU32::new(opt.shard_init), shard_queuer: shard_queue_tx, - shard_total: AtomicU32::new(opt.shard_total), shard_shutdown: Mutex::new(shutdown_recv), shard_shutdown_send: shutdown_send, runners: Arc::clone(&runners), @@ -155,6 +143,7 @@ impl ShardManager { #[cfg(feature = "voice")] voice_manager: opt.voice_manager, ws_url: opt.ws_url, + shard_total: opt.shard_total, #[cfg(feature = "cache")] cache: opt.cache, http: opt.http, @@ -182,34 +171,22 @@ impl ShardManager { /// This will communicate shard boots with the [`ShardQueuer`] so that they are properly /// queued. #[instrument(skip(self))] - pub fn initialize(&self) -> Result<()> { - let shard_index = self.shard_index.load(Ordering::Relaxed); - let shard_init = self.shard_init.load(Ordering::Relaxed); - let shard_total = self.shard_total.load(Ordering::Relaxed); - + pub fn initialize( + &self, + shard_index: u16, + shard_init: u16, + shard_total: NonZeroU16, + ) -> Result<()> { let shard_to = shard_index + shard_init; + self.set_shard_total(shard_total); for shard_id in shard_index..shard_to { - self.boot([ShardId(shard_id), ShardId(shard_total)]); + self.boot(ShardId(shard_id)); } Ok(()) } - /// Sets the new sharding information for the manager. - /// - /// This will shutdown all existing shards. - /// - /// This will _not_ instantiate the new shards. - #[instrument(skip(self))] - pub async fn set_shards(&self, index: u32, init: u32, total: u32) { - self.shutdown_all().await; - - self.shard_index.store(index, Ordering::Relaxed); - self.shard_init.store(init, Ordering::Relaxed); - self.shard_total.store(total, Ordering::Relaxed); - } - /// Restarts a shard runner. /// /// This sends a shutdown signal to a shard's associated [`ShardRunner`], and then queues a @@ -232,12 +209,9 @@ impl ShardManager { /// [`ShardRunner`]: super::ShardRunner #[instrument(skip(self))] pub async fn restart(&self, shard_id: ShardId) { - info!("Restarting shard {}", shard_id); + info!("Restarting shard {shard_id}"); self.shutdown(shard_id, 4000).await; - - let shard_total = self.shard_total.load(Ordering::Relaxed); - - self.boot([shard_id, ShardId(shard_total)]); + self.boot(shard_id); } /// Returns the [`ShardId`]s of the shards that have been instantiated and currently have a @@ -324,12 +298,18 @@ impl ShardManager { drop(self.return_value_tx.lock().await.unbounded_send(Ok(()))); } - #[instrument(skip(self))] - fn boot(&self, shard_info: [ShardId; 2]) { - info!("Telling shard queuer to start shard {}", shard_info[0]); + fn set_shard_total(&self, shard_total: NonZeroU16) { + info!("Setting shard total to {shard_total}"); - let msg = ShardQueuerMessage::Start(shard_info[0], shard_info[1]); + let msg = ShardQueuerMessage::SetShardTotal(shard_total); + drop(self.shard_queuer.unbounded_send(msg)); + } + #[instrument(skip(self))] + fn boot(&self, shard_id: ShardId) { + info!("Telling shard queuer to start shard {shard_id}"); + + let msg = ShardQueuerMessage::Start(shard_id); drop(self.shard_queuer.unbounded_send(msg)); } @@ -351,10 +331,10 @@ impl ShardManager { } } - pub async fn restart_shard(&self, id: ShardId) { - self.restart(id).await; - if let Err(e) = self.shard_shutdown_send.unbounded_send(id) { - tracing::warn!("failed to notify about finished shutdown: {}", e); + pub async fn restart_shard(&self, shard_id: ShardId) { + self.restart(shard_id).await; + if let Err(e) = self.shard_shutdown_send.unbounded_send(shard_id) { + tracing::warn!("failed to notify about finished shutdown: {e}"); } } @@ -389,12 +369,10 @@ pub struct ShardManagerOptions { pub raw_event_handlers: Vec>, #[cfg(feature = "framework")] pub framework: Arc>>, - pub shard_index: u32, - pub shard_init: u32, - pub shard_total: u32, #[cfg(feature = "voice")] pub voice_manager: Option>, - pub ws_url: Arc>, + pub ws_url: Arc, + pub shard_total: NonZeroU16, #[cfg(feature = "cache")] pub cache: Arc, pub http: Arc, diff --git a/src/gateway/bridge/shard_messenger.rs b/src/gateway/bridge/shard_messenger.rs index 685cc7eb0ed..efd1b6c9133 100644 --- a/src/gateway/bridge/shard_messenger.rs +++ b/src/gateway/bridge/shard_messenger.rs @@ -58,24 +58,17 @@ impl ShardMessenger { /// parameter: /// /// ```rust,no_run - /// # use tokio::sync::Mutex; - /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; - /// # use serenity::model::id::ShardId; /// # use serenity::gateway::{ChunkGuildFilter, Shard}; - /// # use std::sync::Arc; - /// # - /// # async fn run() -> Result<(), Box> { - /// # let mutex = Arc::new(Mutex::new("".to_string())); - /// # - /// # let shard_info = ShardInfo { - /// # id: ShardId(0), - /// # total: 1, - /// # }; - /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all(), None).await?; - /// # + /// # async fn run(mut shard: Shard) -> Result<(), Box> { /// use serenity::model::id::GuildId; /// - /// shard.chunk_guild(GuildId::new(81384788765712384), Some(2000), false, ChunkGuildFilter::None, None); + /// shard.chunk_guild( + /// GuildId::new(81384788765712384), + /// Some(2000), + /// false, + /// ChunkGuildFilter::None, + /// None, + /// ); /// # Ok(()) /// # } /// ``` @@ -84,22 +77,8 @@ impl ShardMessenger { /// and a nonce of `"request"`: /// /// ```rust,no_run - /// # use tokio::sync::Mutex; - /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; - /// # use serenity::model::id::ShardId; /// # use serenity::gateway::{ChunkGuildFilter, Shard}; - /// # use std::sync::Arc; - /// # - /// # async fn run() -> Result<(), Box> { - /// # let mutex = Arc::new(Mutex::new("".to_string())); - /// # - /// # let shard_info = ShardInfo { - /// # id: ShardId(0), - /// # total: 1, - /// # }; - /// # - /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all(), None).await?;; - /// # + /// # async fn run(mut shard: Shard) -> Result<(), Box> { /// use serenity::model::id::GuildId; /// /// shard.chunk_guild( @@ -138,21 +117,8 @@ impl ShardMessenger { /// Setting the current activity to playing `"Heroes of the Storm"`: /// /// ```rust,no_run - /// # use tokio::sync::Mutex; - /// # use serenity::gateway::{Shard}; - /// # use serenity::model::id::ShardId; - /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; - /// # use std::sync::Arc; - /// # - /// # async fn run() -> Result<(), Box> { - /// # let mutex = Arc::new(Mutex::new("".to_string())); - /// # - /// # let shard_info = ShardInfo { - /// # id: ShardId(0), - /// # total: 1, - /// # }; - /// # - /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all(), None).await?; + /// # use serenity::gateway::Shard; + /// # async fn run(mut shard: Shard) -> Result<(), Box> { /// use serenity::gateway::ActivityData; /// /// shard.set_activity(Some(ActivityData::playing("Heroes of the Storm"))); @@ -172,20 +138,8 @@ impl ShardMessenger { /// Set the current user as playing `"Heroes of the Storm"` and being online: /// /// ```rust,ignore - /// # use tokio::sync::Mutex; /// # use serenity::gateway::Shard; - /// # use std::sync::Arc; - /// # - /// # async fn run() -> Result<(), Box> { - /// # let mutex = Arc::new(Mutex::new("".to_string())); - /// # - /// # let shard_info = ShardInfo { - /// # id: 0, - /// # total: 1, - /// # }; - /// # - /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, None).await?; - /// # + /// # async fn run(shard: Shard) -> Result<(), Box> { /// use serenity::gateway::ActivityData; /// use serenity::model::user::OnlineStatus; /// @@ -214,21 +168,8 @@ impl ShardMessenger { /// Setting the current online status for the shard to [`DoNotDisturb`]. /// /// ```rust,no_run - /// # use tokio::sync::Mutex; - /// # use serenity::gateway::{Shard}; - /// # use serenity::model::id::ShardId; - /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; - /// # use std::sync::Arc; - /// # - /// # async fn run() -> Result<(), Box> { - /// # let mutex = Arc::new(Mutex::new("".to_string())); - /// # let shard_info = ShardInfo { - /// # id: ShardId(0), - /// # total: 1, - /// # }; - /// # - /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all(), None).await?; - /// # + /// # use serenity::gateway::Shard; + /// # async fn run(mut shard: Shard) -> Result<(), Box> { /// use serenity::model::user::OnlineStatus; /// /// shard.set_status(OnlineStatus::DoNotDisturb); diff --git a/src/gateway/bridge/shard_queuer.rs b/src/gateway/bridge/shard_queuer.rs index 03ade62ecb7..cc72dd0faf6 100644 --- a/src/gateway/bridge/shard_queuer.rs +++ b/src/gateway/bridge/shard_queuer.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, VecDeque}; +use std::num::NonZeroU16; use std::sync::Arc; #[cfg(feature = "framework")] use std::sync::OnceLock; @@ -63,7 +64,7 @@ pub struct ShardQueuer { /// The shards that are queued for booting. /// /// This will typically be filled with previously failed boots. - pub queue: VecDeque, + pub queue: VecDeque, /// A copy of the map of shard runners. pub runners: Arc>>, /// A receiver channel for the shard queuer to be told to start shards. @@ -72,7 +73,9 @@ pub struct ShardQueuer { #[cfg(feature = "voice")] pub voice_manager: Option>, /// A copy of the URL to use to connect to the gateway. - pub ws_url: Arc>, + pub ws_url: Arc, + /// The total amount of shards to start. + pub shard_total: NonZeroU16, #[cfg(feature = "cache")] pub cache: Arc, pub http: Arc, @@ -116,14 +119,16 @@ impl ShardQueuer { debug!("[Shard Queuer] Received to shutdown shard {} with {}.", shard.0, code); self.shutdown(shard, code).await; }, - Ok(Some(ShardQueuerMessage::Start(id, total))) => { - debug!("[Shard Queuer] Received to start shard {} of {}.", id.0, total.0); - self.checked_start(id, total.0).await; + Ok(Some(ShardQueuerMessage::Start(shard_id))) => { + self.checked_start(shard_id).await; + }, + Ok(Some(ShardQueuerMessage::SetShardTotal(shard_total))) => { + self.shard_total = shard_total; }, Ok(None) => break, Err(_) => { if let Some(shard) = self.queue.pop_front() { - self.checked_start(shard.id, shard.total).await; + self.checked_start(shard).await; } }, } @@ -148,28 +153,26 @@ impl ShardQueuer { } #[instrument(skip(self))] - async fn checked_start(&mut self, id: ShardId, total: u32) { - debug!("[Shard Queuer] Checked start for shard {} out of {}", id, total); - self.check_last_start().await; + async fn checked_start(&mut self, shard_id: ShardId) { + debug!("[Shard Queuer] Checked start for shard {shard_id}"); - if let Err(why) = self.start(id, total).await { - warn!("[Shard Queuer] Err starting shard {}: {:?}", id, why); - info!("[Shard Queuer] Re-queueing start of shard {}", id); + self.check_last_start().await; + if let Err(why) = self.start(shard_id).await { + warn!("[Shard Queuer] Err starting shard {shard_id}: {why:?}"); + info!("[Shard Queuer] Re-queueing start of shard {shard_id}"); - self.queue.push_back(ShardInfo::new(id, total)); + self.queue.push_back(shard_id); } self.last_start = Some(Instant::now()); } #[instrument(skip(self))] - async fn start(&mut self, id: ShardId, total: u32) -> Result<()> { - let shard_info = ShardInfo::new(id, total); - + async fn start(&mut self, shard_id: ShardId) -> Result<()> { let mut shard = Shard::new( Arc::clone(&self.ws_url), self.http.token(), - shard_info, + ShardInfo::new(shard_id, self.shard_total), self.intents, self.presence.clone(), ) @@ -204,7 +207,7 @@ impl ShardQueuer { debug!("[ShardRunner {:?}] Stopping", runner.shard.shard_info()); }); - self.runners.lock().await.insert(id, runner_info); + self.runners.lock().await.insert(shard_id, runner_info); Ok(()) } diff --git a/src/gateway/bridge/shard_runner.rs b/src/gateway/bridge/shard_runner.rs index b349ee61ed1..26d5dddb8b6 100644 --- a/src/gateway/bridge/shard_runner.rs +++ b/src/gateway/bridge/shard_runner.rs @@ -331,7 +331,9 @@ impl ShardRunner { }, Event::VoiceServerUpdate(event) => { if let Some(guild_id) = event.guild_id { - voice_manager.server_update(guild_id, &event.endpoint, &event.token).await; + voice_manager + .server_update(guild_id, event.endpoint.as_deref(), &event.token) + .await; } }, Event::VoiceStateUpdate(event) => { diff --git a/src/gateway/bridge/voice.rs b/src/gateway/bridge/voice.rs index 7f03113070f..8f3bc6ce7c2 100644 --- a/src/gateway/bridge/voice.rs +++ b/src/gateway/bridge/voice.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroU16; + use async_trait::async_trait; use futures::channel::mpsc::UnboundedSender as Sender; @@ -14,7 +16,7 @@ pub trait VoiceGatewayManager: Send + Sync { /// Performs initial setup at the start of a connection to Discord. /// /// This will only occur once, and provides the bot's ID and shard count. - async fn initialise(&self, shard_count: u32, user_id: UserId); + async fn initialise(&self, shard_count: NonZeroU16, user_id: UserId); /// Handler fired in response to a [`Ready`] event. /// @@ -22,19 +24,19 @@ pub trait VoiceGatewayManager: Send + Sync { /// active shard. /// /// [`Ready`]: crate::model::event::Event - async fn register_shard(&self, shard_id: u32, sender: Sender); + async fn register_shard(&self, shard_id: u16, sender: Sender); /// Handler fired in response to a disconnect, reconnection, or rebalance. /// /// This event invalidates the last sender associated with `shard_id`. Unless the bot is fully /// disconnecting, this is often followed by a call to [`Self::register_shard`]. Users may wish /// to buffer manually any gateway messages sent between these calls. - async fn deregister_shard(&self, shard_id: u32); + async fn deregister_shard(&self, shard_id: u16); /// Handler for VOICE_SERVER_UPDATE messages. /// /// These contain the endpoint and token needed to form a voice connection session. - async fn server_update(&self, guild_id: GuildId, endpoint: &Option, token: &str); + async fn server_update(&self, guild_id: GuildId, endpoint: Option<&str>, token: &str); /// Handler for VOICE_STATE_UPDATE messages. /// diff --git a/src/gateway/shard.rs b/src/gateway/shard.rs index a2545bdf8e9..a0a582afdeb 100644 --- a/src/gateway/shard.rs +++ b/src/gateway/shard.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use std::time::{Duration as StdDuration, Instant}; -use tokio::sync::Mutex; use tokio_tungstenite::tungstenite::error::Error as TungsteniteError; use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame; use tracing::{debug, error, info, instrument, trace, warn}; @@ -73,7 +72,7 @@ pub struct Shard { // a decent amount of time. pub started: Instant, pub token: String, - ws_url: Arc>, + ws_url: Arc, pub intents: GatewayIntents, } @@ -87,6 +86,7 @@ impl Shard { /// Instantiating a new Shard manually for a bot with no shards, and then listening for events: /// /// ```rust,no_run + /// use std::num::NonZeroU16; /// use std::sync::Arc; /// /// use serenity::gateway::Shard; @@ -101,11 +101,11 @@ impl Shard { /// let token = std::env::var("DISCORD_BOT_TOKEN")?; /// let shard_info = ShardInfo { /// id: ShardId(0), - /// total: 1, + /// total: NonZeroU16::MIN, /// }; /// /// // retrieve the gateway response, which contains the URL to connect to - /// let gateway = Arc::new(Mutex::new(http.get_gateway().await?.url)); + /// let gateway = Arc::from(http.get_gateway().await?.url); /// let shard = Shard::new(gateway, &token, shard_info, GatewayIntents::all(), None).await?; /// /// // at this point, you can create a `loop`, and receive events and match @@ -119,14 +119,13 @@ impl Shard { /// On Error, will return either [`Error::Gateway`], [`Error::Tungstenite`] or a Rustls/native /// TLS error. pub async fn new( - ws_url: Arc>, + ws_url: Arc, token: &str, shard_info: ShardInfo, intents: GatewayIntents, presence: Option, ) -> Result { - let url = ws_url.lock().await.clone(); - let client = connect(&url).await?; + let client = connect(&ws_url).await?; let presence = presence.unwrap_or_default(); let last_heartbeat_sent = None; @@ -595,24 +594,19 @@ impl Shard { /// specifying a query parameter: /// /// ```rust,no_run - /// # use tokio::sync::Mutex; /// # use serenity::gateway::{ChunkGuildFilter, Shard}; - /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; - /// # use serenity::model::id::ShardId; - /// # use std::sync::Arc; - /// # - /// # async fn run() -> Result<(), Box> { - /// # let mutex = Arc::new(Mutex::new("".to_string())); - /// # let shard_info = ShardInfo { - /// # id: ShardId(0), - /// # total: 1, - /// # }; - /// # - /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all(), None).await?; - /// # + /// # async fn run(mut shard: Shard) -> Result<(), Box> { /// use serenity::model::id::GuildId; /// - /// shard.chunk_guild(GuildId::new(81384788765712384), Some(2000), false, ChunkGuildFilter::None, None).await?; + /// shard + /// .chunk_guild( + /// GuildId::new(81384788765712384), + /// Some(2000), + /// false, + /// ChunkGuildFilter::None, + /// None, + /// ) + /// .await?; /// # Ok(()) /// # } /// ``` @@ -621,22 +615,8 @@ impl Shard { /// `"do"` and a nonce of `"request"`: /// /// ```rust,no_run - /// # use tokio::sync::Mutex; - /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; /// # use serenity::gateway::{ChunkGuildFilter, Shard}; - /// # use serenity::model::id::ShardId; - /// # use std::error::Error; - /// # use std::sync::Arc; - /// # - /// # async fn run() -> Result<(), Box> { - /// # let mutex = Arc::new(Mutex::new("".to_string())); - /// # - /// # let shard_info = ShardInfo { - /// # id: ShardId(0), - /// # total: 1, - /// # }; - /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all(), None).await?; - /// # + /// # async fn run(mut shard: Shard) -> Result<(), Box> { /// use serenity::model::id::GuildId; /// /// shard @@ -702,8 +682,7 @@ impl Shard { // Hello is received. self.stage = ConnectionStage::Connecting; self.started = Instant::now(); - let url = &self.ws_url.lock().await.clone(); - let client = connect(url).await?; + let client = connect(&self.ws_url).await?; self.stage = ConnectionStage::Handshake; Ok(client) diff --git a/src/model/gateway.rs b/src/model/gateway.rs index 16b35669460..15ca1a80c94 100644 --- a/src/model/gateway.rs +++ b/src/model/gateway.rs @@ -22,7 +22,7 @@ pub struct BotGateway { /// The gateway to connect to. pub url: String, /// The number of shards that is recommended to be used by the current bot user. - pub shards: u32, + pub shards: NonZeroU16, /// Information describing how many gateway sessions you can initiate within a ratelimit /// period. pub session_start_limit: SessionStartLimit, @@ -372,12 +372,12 @@ pub struct SessionStartLimit { #[derive(Clone, Copy, Debug)] pub struct ShardInfo { pub id: ShardId, - pub total: u32, + pub total: NonZeroU16, } impl ShardInfo { #[must_use] - pub(crate) fn new(id: ShardId, total: u32) -> Self { + pub(crate) fn new(id: ShardId, total: NonZeroU16) -> Self { Self { id, total, @@ -387,7 +387,7 @@ impl ShardInfo { impl<'de> serde::Deserialize<'de> for ShardInfo { fn deserialize>(deserializer: D) -> StdResult { - <(u32, u32)>::deserialize(deserializer).map(|(id, total)| ShardInfo { + <(u16, NonZeroU16)>::deserialize(deserializer).map(|(id, total)| ShardInfo { id: ShardId(id), total, }) diff --git a/src/model/guild/guild_id.rs b/src/model/guild/guild_id.rs index e2fa2352c94..10d0cf78716 100644 --- a/src/model/guild/guild_id.rs +++ b/src/model/guild/guild_id.rs @@ -1345,8 +1345,8 @@ impl GuildId { #[cfg(all(feature = "cache", feature = "utils"))] #[inline] #[must_use] - pub fn shard_id(self, cache: impl AsRef) -> u32 { - crate::utils::shard_id(self, cache.as_ref().shard_count()) + pub fn shard_id(self, cache: impl AsRef) -> u16 { + crate::utils::shard_id(self, cache.as_ref().shard_count().get()) } /// Returns the Id of the shard associated with the guild. @@ -1371,7 +1371,7 @@ impl GuildId { #[cfg(all(feature = "utils", not(feature = "cache")))] #[inline] #[must_use] - pub fn shard_id(self, shard_count: u32) -> u32 { + pub fn shard_id(self, shard_count: u16) -> u16 { crate::utils::shard_id(self, shard_count) } diff --git a/src/model/guild/mod.rs b/src/model/guild/mod.rs index daac25a4b9c..4e18a4147d3 100644 --- a/src/model/guild/mod.rs +++ b/src/model/guild/mod.rs @@ -2150,7 +2150,7 @@ impl Guild { /// [`utils::shard_id`]: crate::utils::shard_id #[cfg(all(feature = "cache", feature = "utils"))] #[inline] - pub fn shard_id(&self, cache: impl AsRef) -> u32 { + pub fn shard_id(&self, cache: impl AsRef) -> u16 { self.id.shard_id(&cache) } @@ -2175,7 +2175,7 @@ impl Guild { #[cfg(all(feature = "utils", not(feature = "cache")))] #[inline] #[must_use] - pub fn shard_id(&self, shard_count: u32) -> u32 { + pub fn shard_id(&self, shard_count: u16) -> u16 { self.id.shard_id(shard_count) } diff --git a/src/model/guild/partial_guild.rs b/src/model/guild/partial_guild.rs index 6205bb449c7..f85fa214778 100644 --- a/src/model/guild/partial_guild.rs +++ b/src/model/guild/partial_guild.rs @@ -1387,7 +1387,7 @@ impl PartialGuild { #[cfg(all(feature = "cache", feature = "utils"))] #[inline] #[must_use] - pub fn shard_id(&self, cache: impl AsRef) -> u32 { + pub fn shard_id(&self, cache: impl AsRef) -> u16 { self.id.shard_id(cache) } @@ -1412,7 +1412,7 @@ impl PartialGuild { #[cfg(all(feature = "utils", not(feature = "cache")))] #[inline] #[must_use] - pub fn shard_id(&self, shard_count: u32) -> u32 { + pub fn shard_id(&self, shard_count: u16) -> u16 { self.id.shard_id(shard_count) } diff --git a/src/model/id.rs b/src/model/id.rs index d046453de87..d7d7d681797 100644 --- a/src/model/id.rs +++ b/src/model/id.rs @@ -287,7 +287,7 @@ id_u64! { /// and therefore cannot be [`Serialize`]d or [`Deserialize`]d. #[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))] #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] -pub struct ShardId(pub u32); +pub struct ShardId(pub u16); impl fmt::Display for ShardId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/src/model/invite.rs b/src/model/invite.rs index db69ea32d8a..ab2b09bf3d0 100644 --- a/src/model/invite.rs +++ b/src/model/invite.rs @@ -236,7 +236,7 @@ impl InviteGuild { #[cfg(all(feature = "cache", feature = "utils"))] #[inline] #[must_use] - pub fn shard_id(&self, cache: impl AsRef) -> u32 { + pub fn shard_id(&self, cache: impl AsRef) -> u16 { self.id.shard_id(&cache) } @@ -261,7 +261,7 @@ impl InviteGuild { #[cfg(all(feature = "utils", not(feature = "cache")))] #[inline] #[must_use] - pub fn shard_id(&self, shard_count: u32) -> u32 { + pub fn shard_id(&self, shard_count: u16) -> u16 { self.id.shard_id(shard_count) } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 367c28bd9f8..3950f6fe270 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -22,6 +22,7 @@ pub use content_safe::*; pub use formatted_timestamp::*; #[cfg(feature = "collector")] pub use quick_modal::*; +use tracing::warn; use url::Url; pub use self::custom_message::CustomMessage; @@ -485,8 +486,16 @@ pub(crate) fn user_perms(cache: impl AsRef, channel_id: ChannelId) -> Res /// ``` #[inline] #[must_use] -pub fn shard_id(guild_id: GuildId, shard_count: u32) -> u32 { - ((guild_id.get() >> 22) % (shard_count as u64)) as u32 +pub fn shard_id(guild_id: GuildId, shard_count: u16) -> u16 { + let shard_count = check_shard_total(shard_count); + ((guild_id.get() >> 22) % (shard_count.get() as u64)) as u16 +} + +pub(crate) fn check_shard_total(total_shards: u16) -> NonZeroU16 { + NonZeroU16::new(total_shards).unwrap_or_else(|| { + warn!("Invalid shard total provided ({total_shards}), defaulting to 1"); + NonZeroU16::MIN + }) } #[cfg(test)]