From b213ce545dc2f932b1a8e91a35ce410c8d42b697 Mon Sep 17 00:00:00 2001 From: Joe L <56809242+jo3-l@users.noreply.github.com> Date: Fri, 21 Jun 2024 08:42:01 -0700 Subject: [PATCH] reminders: verify that member can run `remindme` in target channel (#1674) * commands: move code handling cmd overrides for threads into YagCommand.GetSettings Previously this logic was done by the caller, but it should be the responsibility of the GetSettings procedure. This commit also incidentally fixes a minor bug in CommonContainerNotFoundHandler, which previously did not consider command overrides properly if used in a thread. * reminders: verify that member can run `remindme` in target channel --- commands/util.go | 3 +-- commands/yagcommmand.go | 37 +++++++++++++++++-------------------- reminders/plugin_bot.go | 37 ++++++++++++++++++++++++++++++++++--- 3 files changed, 52 insertions(+), 25 deletions(-) diff --git a/commands/util.go b/commands/util.go index 6cc34500eb..5c4685fa04 100644 --- a/commands/util.go +++ b/commands/util.go @@ -154,10 +154,9 @@ func CommonContainerNotFoundHandler(container *dcmd.Container, fixedMessage stri return func(data *dcmd.Data) (interface{}, error) { // Only show stuff if atleast 1 of the commands in the container is enabled if data.GuildData != nil { - cParentID := data.GuildData.CS.ParentID ms := data.GuildData.MS - channelOverrides, err := GetOverridesForChannel(data.ChannelID, cParentID, data.GuildData.GS.ID) + channelOverrides, err := GetOverridesForChannel(data.GuildData.CS, data.GuildData.GS) if err != nil { logger.WithError(err).WithField("guild", data.GuildData.GS.ID).Error("failed retrieving command overrides") return nil, nil diff --git a/commands/yagcommmand.go b/commands/yagcommmand.go index 4128ae281c..3bf1127c48 100644 --- a/commands/yagcommmand.go +++ b/commands/yagcommmand.go @@ -437,16 +437,8 @@ func (yc *YAGCommand) checkCanExecuteCommand(data *dcmd.Data) (canExecute bool, return false, nil, nil, nil } } - channel_id := data.GuildData.CS.ID - parent_id := data.GuildData.CS.ParentID - // in case the channel is a thread, get the parent channel from parent id and check for the overrides - if data.GuildData.CS.Type.IsThread() { - channel := data.GuildData.GS.GetChannel(parent_id) - channel_id = channel.ID - parent_id = channel.ParentID - } - settings, err = yc.GetSettings(data.ContainerChain, channel_id, parent_id, guild.ID) + settings, err = yc.GetSettings(data.ContainerChain, data.GuildData.CS, guild) if err != nil { resp = &CanExecuteError{ Type: ReasonError, @@ -721,9 +713,14 @@ type CommandSettings struct { IgnoreRoles []int64 } -func GetOverridesForChannel(channelID, channelParentID, guildID int64) ([]*models.CommandsChannelsOverride, error) { +func GetOverridesForChannel(cs *dstate.ChannelState, guild *dstate.GuildSet) ([]*models.CommandsChannelsOverride, error) { + if cs.Type.IsThread() { + // Look for overrides from the parent channel, not the thread. + cs = guild.GetChannel(cs.ParentID) + } + // Fetch the overrides from the database, we treat the global settings as an override for simplicity - channelOverrides, err := models.CommandsChannelsOverrides(qm.Where("(? = ANY (channels) OR global=true OR ? = ANY (channel_categories)) AND guild_id=?", channelID, channelParentID, guildID), qm.Load("CommandsCommandOverrides")).AllG(context.Background()) + channelOverrides, err := models.CommandsChannelsOverrides(qm.Where("(? = ANY (channels) OR global=true OR ? = ANY (channel_categories)) AND guild_id=?", cs.ID, cs.ParentID, guild.ID), qm.Load("CommandsCommandOverrides")).AllG(context.Background()) if err != nil { return nil, err } @@ -732,23 +729,23 @@ func GetOverridesForChannel(channelID, channelParentID, guildID int64) ([]*model } // GetSettings returns the settings from the command, generated from the servers channel and command overrides -func (cs *YAGCommand) GetSettings(containerChain []*dcmd.Container, channelID, channelParentID, guildID int64) (settings *CommandSettings, err error) { +func (yc *YAGCommand) GetSettings(containerChain []*dcmd.Container, cs *dstate.ChannelState, guild *dstate.GuildSet) (settings *CommandSettings, err error) { // Fetch the overrides from the database, we treat the global settings as an override for simplicity - channelOverrides, err := GetOverridesForChannel(channelID, channelParentID, guildID) + channelOverrides, err := GetOverridesForChannel(cs, guild) if err != nil { err = errors.WithMessage(err, "GetOverridesForChannel") return } - return cs.GetSettingsWithLoadedOverrides(containerChain, guildID, channelOverrides) + return yc.GetSettingsWithLoadedOverrides(containerChain, guild.ID, channelOverrides) } -func (cs *YAGCommand) GetSettingsWithLoadedOverrides(containerChain []*dcmd.Container, guildID int64, channelOverrides []*models.CommandsChannelsOverride) (settings *CommandSettings, err error) { +func (yc *YAGCommand) GetSettingsWithLoadedOverrides(containerChain []*dcmd.Container, guildID int64, channelOverrides []*models.CommandsChannelsOverride) (settings *CommandSettings, err error) { settings = &CommandSettings{} // Some commands have custom places to toggle their enabled status - ce, err := cs.customEnabled(guildID) + ce, err := yc.customEnabled(guildID) if err != nil { err = errors.WithMessage(err, "customEnabled") return @@ -758,7 +755,7 @@ func (cs *YAGCommand) GetSettingsWithLoadedOverrides(containerChain []*dcmd.Cont return } - if cs.HideFromCommandsPage { + if yc.HideFromCommandsPage { settings.Enabled = true return } @@ -780,7 +777,7 @@ func (cs *YAGCommand) GetSettingsWithLoadedOverrides(containerChain []*dcmd.Cont } } - cmdFullName := cs.Name + cmdFullName := yc.Name if len(containerChain) > 1 { lastContainer := containerChain[len(containerChain)-1] cmdFullName = lastContainer.Names[0] + " " + cmdFullName @@ -788,12 +785,12 @@ func (cs *YAGCommand) GetSettingsWithLoadedOverrides(containerChain []*dcmd.Cont // Assign the global settings, if existing if global != nil { - cs.fillSettings(cmdFullName, global, settings) + yc.fillSettings(cmdFullName, global, settings) } // Assign the channel override, if existing if channelOverride != nil { - cs.fillSettings(cmdFullName, channelOverride, settings) + yc.fillSettings(cmdFullName, channelOverride, settings) } return diff --git a/reminders/plugin_bot.go b/reminders/plugin_bot.go index 354237b70a..8743500af9 100644 --- a/reminders/plugin_bot.go +++ b/reminders/plugin_bot.go @@ -76,15 +76,37 @@ var cmds = []*commands.YAGCommand{ id := parsed.ChannelID if c := parsed.Switch("channel"); c.Value != nil { - id = c.Value.(*dstate.ChannelState).ID + cs := c.Value.(*dstate.ChannelState) + mention, _ := cs.Mention() - hasPerms, err := bot.AdminOrPermMS(parsed.GuildData.GS.ID, id, parsed.GuildData.MS, discordgo.PermissionSendMessages|discordgo.PermissionViewChannel) + hasPerms, err := bot.AdminOrPermMS(parsed.GuildData.GS.ID, cs.ID, parsed.GuildData.MS, discordgo.PermissionSendMessages|discordgo.PermissionViewChannel) if err != nil { return "Failed checking permissions; please try again or join the support server.", err } if !hasPerms { - return "You do not have permissions to send messages there", nil + return fmt.Sprintf("You do not have permissions to send messages in %s", mention), nil + } + + // Ensure the member can run the `remindme` command in the + // target channel according to the configured command and + // channel overrides, if any. + yc := parsed.Cmd.Command.(*commands.YAGCommand) + settings, err := yc.GetSettings(parsed.ContainerChain, cs, parsed.GuildData.GS) + if err != nil { + return "Failed fetching command settings", err + } + + if !settings.Enabled { + return fmt.Sprintf("The `remindme` command is disabled in %s", mention), nil + } + + ms := parsed.GuildData.MS + // If there are no required roles set, the member should be allowed to run the command. + hasRequiredRoles := len(settings.RequiredRoles) == 0 || memberHasAnyRole(ms, settings.RequiredRoles) + hasIgnoredRoles := memberHasAnyRole(ms, settings.IgnoreRoles) + if !hasRequiredRoles || hasIgnoredRoles { + return fmt.Sprintf("You cannot use the `remindme` command in %s", mention), nil } } @@ -230,6 +252,15 @@ var cmds = []*commands.YAGCommand{ }, } +func memberHasAnyRole(ms *dstate.MemberState, roles []int64) bool { + for _, r := range ms.Member.Roles { + if common.ContainsInt64Slice(roles, r) { + return true + } + } + return false +} + func checkUserScheduledEvent(evt *seventsmodels.ScheduledEvent, data interface{}) (retry bool, err error) { // IMPORTANT: evt.GuildID can be 1 in cases where it was migrated from the // legacy scheduled event system.