diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 40a17622..051e6a00 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -59,6 +59,7 @@ type BridgeConfig struct { CommandPrefix string `yaml:"command_prefix"` PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + AsyncEvents bool `yaml:"async_events"` BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` TagOnlyOnCreate bool `yaml:"tag_only_on_create"` MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 7e524e84..d6ccf007 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -25,6 +25,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "bridge", "command_prefix") helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") helper.Copy(up.Bool, "bridge", "private_chat_portal_meta") + helper.Copy(up.Bool, "bridge", "async_events") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "tag_only_on_create") helper.Copy(up.Bool, "bridge", "mute_only_on_create") diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 48af69ae..8df4ae22 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -289,7 +289,8 @@ func (as *ASIntent) UploadMediaStream( var res *bridgev2.FileStreamResult res, err = cb(tempFile) if err != nil { - err = fmt.Errorf("failed to write to temp file: %w", err) + err = fmt.Errorf("write callback failed: %w", err) + return } var replFile *os.File if res.ReplacementFile != "" { diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 9c7ea65a..c738e235 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -7,6 +7,9 @@ bridge: # Whether the bridge should set names and avatars explicitly for DM portals. # This is only necessary when using clients that don't support MSC4171. private_chat_portal_meta: false + # Should events be handled asynchronously within portal rooms? + # If true, events may end up being out of order, but slow events won't block other ones. + async_events: false # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index c1f02890..00f5eb72 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -54,6 +54,11 @@ type ProvisioningAPI struct { // GetAuthFromRequest is a custom function for getting the auth token from // the request if the Authorization header is not present. GetAuthFromRequest func(r *http.Request) string + + // GetUserIDFromRequest is a custom function for getting the user ID to + // authenticate as instead of using the user ID provided in the query + // parameter. + GetUserIDFromRequest func(r *http.Request) id.UserID } type ProvLogin struct { @@ -200,6 +205,9 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return } userID := id.UserID(r.URL.Query().Get("user_id")) + if userID == "" && prov.GetUserIDFromRequest != nil { + userID = prov.GetUserIDFromRequest(r) + } if auth != prov.br.Config.Provisioning.SharedSecret { var err error if strings.HasPrefix(auth, "openid:") { @@ -227,6 +235,14 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return } // TODO handle user being nil? + // TODO per-endpoint permissions? + if !user.Permissions.Login { + jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ + Err: "User does not have login permissions", + ErrCode: mautrix.MForbidden.ErrCode, + }) + return + } ctx := context.WithValue(r.Context(), provisioningUserKey, user) if loginID, ok := mux.Vars(r)["loginProcessID"]; ok { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 556b7407..3e0617ae 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -273,11 +273,16 @@ type MaxFileSizeingNetwork interface { SetMaxFileSize(maxSize int64) } +type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error) + type MatrixMessageResponse struct { DB *database.Message - - Pending networkid.TransactionID - HandleEcho func(RemoteMessage, *database.Message) (bool, error) + // If Pending is set, the bridge will not save the provided message to the database. + // This should only be used if AddPendingToSave has been called. + Pending bool + // If RemovePending is set, the bridge will remove the provided transaction ID from pending messages + // after saving the provided message to the database. This should be used with AddPendingToIgnore. + RemovePending networkid.TransactionID } type FileRestriction struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 4ba7a8f6..ca3be3e4 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -14,6 +14,7 @@ import ( "runtime/debug" "strings" "sync" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -59,6 +60,7 @@ type portalEvent interface { type outgoingMessage struct { db *database.Message evt *event.Event + ignore bool handle func(RemoteMessage, *database.Message) (bool, error) } @@ -274,23 +276,49 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { func (portal *Portal) eventLoop() { for rawEvt := range portal.events { - switch evt := rawEvt.(type) { - case *portalMatrixEvent: - portal.handleMatrixEvent(evt.sender, evt.evt) - case *portalRemoteEvent: - portal.handleRemoteEvent(evt.source, evt.evt) - case *portalCreateEvent: - portal.handleCreateEvent(evt) - default: - panic(fmt.Errorf("illegal type %T in eventLoop", evt)) + portal.handleSingleEventAsync(rawEvt) + } +} + +func (portal *Portal) handleSingleEventAsync(rawEvt any) { + log := portal.Log.With().Logger() + if _, isCreate := rawEvt.(*portalCreateEvent); isCreate { + portal.handleSingleEvent(&log, rawEvt, func() {}) + } else if portal.Bridge.Config.AsyncEvents { + go portal.handleSingleEvent(&log, rawEvt, func() {}) + } else { + doneCh := make(chan struct{}) + var backgrounded atomic.Bool + go portal.handleSingleEvent(&log, rawEvt, func() { + close(doneCh) + if backgrounded.Load() { + log.Debug().Msg("Event that took too long finally finished handling") + } + }) + tick := time.NewTicker(30 * time.Second) + defer tick.Stop() + for i := 0; i < 10; i++ { + select { + case <-doneCh: + if i > 0 { + log.Debug().Msg("Event that took long finished handling") + } + return + case <-tick.C: + log.Warn().Msg("Event handling is taking long") + } } + log.Warn().Msg("Event handling is taking too long, continuing in background") + backgrounded.Store(true) } } -func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) { +func (portal *Portal) handleSingleEvent(log *zerolog.Logger, rawEvt any, doneCallback func()) { + ctx := log.WithContext(context.Background()) defer func() { + doneCallback() if err := recover(); err != nil { - logEvt := zerolog.Ctx(evt.ctx).Error() + logEvt := log.Error() if realErr, ok := err.(error); ok { logEvt = logEvt.Err(realErr) } else { @@ -299,10 +327,36 @@ func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) { logEvt. Bytes("stack", debug.Stack()). Msg("Portal creation panicked") - evt.cb(fmt.Errorf("portal creation panicked")) + switch evt := rawEvt.(type) { + case *portalMatrixEvent: + if evt.evt.ID != "" { + go portal.sendErrorStatus(ctx, evt.evt, ErrPanicInEventHandler) + } + case *portalCreateEvent: + evt.cb(fmt.Errorf("portal creation panicked")) + } } }() - evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil)) + switch evt := rawEvt.(type) { + case *portalMatrixEvent: + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("action", "handle matrix event"). + Stringer("event_id", evt.evt.ID). + Str("event_type", evt.evt.Type.Type) + }) + portal.handleMatrixEvent(ctx, evt.sender, evt.evt) + case *portalRemoteEvent: + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("action", "handle remote event"). + Str("source_id", string(evt.source.ID)) + }) + portal.handleRemoteEvent(ctx, evt.source, evt.evt) + case *portalCreateEvent: + *log = *zerolog.Ctx(evt.ctx) + evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil)) + default: + panic(fmt.Errorf("illegal type %T in eventLoop", evt)) + } } func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { @@ -392,29 +446,8 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, return false } -func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { - log := portal.Log.With(). - Str("action", "handle matrix event"). - Stringer("event_id", evt.ID). - Str("event_type", evt.Type.Type). - Logger() - ctx := log.WithContext(context.TODO()) - defer func() { - if err := recover(); err != nil { - logEvt := log.Error() - if realErr, ok := err.(error); ok { - logEvt = logEvt.Err(realErr) - } else { - logEvt = logEvt.Any(zerolog.ErrorFieldName, err) - } - logEvt. - Bytes("stack", debug.Stack()). - Msg("Matrix event handler panicked") - if evt.ID != "" { - go portal.sendErrorStatus(ctx, evt, ErrPanicInEventHandler) - } - } - }() +func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) { + log := zerolog.Ctx(ctx) if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { switch evt.Type { case event.EphemeralEventReceipt: @@ -775,7 +808,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } - resp, err := sender.Client.HandleMatrixMessage(ctx, &MatrixMessage{ + wrappedEvt := &MatrixMessage{ MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ Event: evt, Content: content, @@ -784,52 +817,30 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin }, ThreadRoot: threadRoot, ReplyTo: replyTo, - }) + } + resp, err := sender.Client.HandleMatrixMessage(ctx, wrappedEvt) if err != nil { log.Err(err).Msg("Failed to handle Matrix message") portal.sendErrorStatus(ctx, evt, err) return } - message := resp.DB - if message.MXID == "" { - message.MXID = evt.ID - } - if message.Room.ID == "" { - message.Room = portal.PortalKey - } - if message.Timestamp.IsZero() { - message.Timestamp = time.UnixMilli(evt.Timestamp) - } - if message.ReplyTo.MessageID == "" && replyTo != nil { - message.ReplyTo.MessageID = replyTo.ID - message.ReplyTo.PartID = &replyTo.PartID - } - if message.ThreadRoot == "" && threadRoot != nil { - message.ThreadRoot = threadRoot.ID - if threadRoot.ThreadRoot != "" { - message.ThreadRoot = threadRoot.ThreadRoot - } - } - if message.SenderMXID == "" { - message.SenderMXID = evt.Sender - } - if resp.Pending != "" { - // TODO if the event queue is ever removed, this will have to be done by the network connector before sending the request - // (for now this is fine because incoming messages will wait in the queue for this function to return) - portal.outgoingMessagesLock.Lock() - portal.outgoingMessages[resp.Pending] = outgoingMessage{ - db: message, - evt: evt, - handle: resp.HandleEcho, - } - portal.outgoingMessagesLock.Unlock() - } else { - // Hack to ensure the ghost row exists - // TODO move to better place (like login) - portal.Bridge.GetGhostByID(ctx, message.SenderID) - err = portal.Bridge.DB.Message.Insert(ctx, message) - if err != nil { - log.Err(err).Msg("Failed to save message to database") + message := wrappedEvt.fillDBMessage(resp.DB) + if !resp.Pending { + if resp.DB == nil { + log.Error().Msg("Network connector didn't return a message to save") + } else { + // Hack to ensure the ghost row exists + // TODO move to better place (like login) + portal.Bridge.GetGhostByID(ctx, message.SenderID) + err = portal.Bridge.DB.Message.Insert(ctx, message) + if err != nil { + log.Err(err).Msg("Failed to save message to database") + } + if resp.RemovePending != "" { + portal.outgoingMessagesLock.Lock() + delete(portal.outgoingMessages, resp.RemovePending) + portal.outgoingMessagesLock.Unlock() + } } portal.sendSuccessStatus(ctx, evt) } @@ -846,6 +857,75 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } +// AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message. +// +// This should be used when the network connector will return the real message ID from HandleMatrixMessage. +// The [MatrixMessageResponse] should include RemovePending with the transaction ID sto remove it from the lit +// after saving to database. +// +// See also: [MatrixMessage.AddPendingToSave] +func (evt *MatrixMessage) AddPendingToIgnore(txnID networkid.TransactionID) { + evt.Portal.outgoingMessagesLock.Lock() + evt.Portal.outgoingMessages[txnID] = outgoingMessage{ + ignore: true, + } + evt.Portal.outgoingMessagesLock.Unlock() +} + +// AddPendingToSave adds a transaction ID that should be processed and pointed at the existing event if encountered. +// +// This should be used when the network connector returns `Pending: true` from HandleMatrixMessage, +// i.e. when the network connector does not know the message ID at the end of the handler. +// The [MatrixMessageResponse] should set Pending to true to prevent saving the returned message to the database. +// +// The provided function will be called when the message is encountered. +func (evt *MatrixMessage) AddPendingToSave(message *database.Message, txnID networkid.TransactionID, handleEcho RemoteEchoHandler) { + evt.Portal.outgoingMessagesLock.Lock() + evt.Portal.outgoingMessages[txnID] = outgoingMessage{ + db: evt.fillDBMessage(message), + evt: evt.Event, + handle: handleEcho, + } + evt.Portal.outgoingMessagesLock.Unlock() +} + +// RemovePending removes a transaction ID from the list of pending messages. +// This should only be called if sending the message fails. +func (evt *MatrixMessage) RemovePending(txnID networkid.TransactionID) { + evt.Portal.outgoingMessagesLock.Lock() + delete(evt.Portal.outgoingMessages, txnID) + evt.Portal.outgoingMessagesLock.Unlock() +} + +func (evt *MatrixMessage) fillDBMessage(message *database.Message) *database.Message { + if message == nil { + message = &database.Message{} + } + if message.MXID == "" { + message.MXID = evt.Event.ID + } + if message.Room.ID == "" { + message.Room = evt.Portal.PortalKey + } + if message.Timestamp.IsZero() { + message.Timestamp = time.UnixMilli(evt.Event.Timestamp) + } + if message.ReplyTo.MessageID == "" && evt.ReplyTo != nil { + message.ReplyTo.MessageID = evt.ReplyTo.ID + message.ReplyTo.PartID = &evt.ReplyTo.PartID + } + if message.ThreadRoot == "" && evt.ThreadRoot != nil { + message.ThreadRoot = evt.ThreadRoot.ID + if evt.ThreadRoot.ThreadRoot != "" { + message.ThreadRoot = evt.ThreadRoot.ThreadRoot + } + } + if message.SenderMXID == "" { + message.SenderMXID = evt.Event.Sender + } + return message +} + func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { log := zerolog.Ctx(ctx) editTargetID := content.RelatesTo.GetReplaceID() @@ -1410,11 +1490,8 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog portal.sendSuccessStatus(ctx, evt) } -func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { - log := portal.Log.With(). - Str("source_id", string(source.ID)). - Str("action", "handle remote event"). - Logger() +func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evt RemoteEvent) { + log := zerolog.Ctx(ctx) defer func() { if err := recover(); err != nil { logEvt := log.Error() @@ -1433,7 +1510,6 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { c = c.Stringer("bridge_evt_type", evtType) return evt.AddLogContext(c) }) - ctx := log.WithContext(context.TODO()) if portal.MXID == "" { mcp, ok := evt.(RemoteEventThatMayCreatePortal) if !ok || !mcp.ShouldCreatePortal() { @@ -1715,6 +1791,8 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage pending, ok := portal.outgoingMessages[txnID] if !ok { return false, nil + } else if pending.ignore { + return true, nil } delete(portal.outgoingMessages, txnID) pending.db.ID = evt.GetID() @@ -1773,7 +1851,8 @@ func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, } if len(res.SubEvents) > 0 { for _, subEvt := range res.SubEvents { - portal.handleRemoteEvent(source, subEvt) + log := portal.Log.With().Str("source_id", string(source.ID)).Str("action", "handle remote subevent").Logger() + portal.handleRemoteEvent(log.WithContext(ctx), source, subEvt) } } return res.ContinueMessageHandling diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index ffe68ca5..e4a3e0ad 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -304,17 +304,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin partIDs = append(partIDs, part.ID) portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) - out.Events = append(out.Events, &event.Event{ - Sender: intent.GetMXID(), - Type: part.Type, - Timestamp: msg.Timestamp.UnixMilli(), - ID: evtID, - RoomID: portal.MXID, - Content: event.Content{ - Parsed: part.Content, - Raw: part.Extra, - }, - }) dbMessage := &database.Message{ ID: msg.ID, PartID: part.ID, @@ -327,6 +316,22 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin ReplyTo: ptr.Val(msg.ReplyTo), Metadata: part.DBMetadata, } + if part.DontBridge { + dbMessage.SetFakeMXID() + out.DBMessages = append(out.DBMessages, dbMessage) + continue + } + out.Events = append(out.Events, &event.Event{ + Sender: intent.GetMXID(), + Type: part.Type, + Timestamp: msg.Timestamp.UnixMilli(), + ID: evtID, + RoomID: portal.MXID, + Content: event.Content{ + Parsed: part.Content, + Raw: part.Extra, + }, + }) if firstPart == nil { firstPart = dbMessage } diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 1ee793a9..a4bd611a 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -37,8 +37,12 @@ func (portal *PortalInternals) EventLoop() { (*Portal)(portal).eventLoop() } -func (portal *PortalInternals) HandleCreateEvent(evt *portalCreateEvent) { - (*Portal)(portal).handleCreateEvent(evt) +func (portal *PortalInternals) HandleSingleEventAsync(rawEvt any) { + (*Portal)(portal).handleSingleEventAsync(rawEvt) +} + +func (portal *PortalInternals) HandleSingleEvent(log *zerolog.Logger, rawEvt any, doneCallback func()) { + (*Portal)(portal).handleSingleEvent(log, rawEvt, doneCallback) } func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event) { @@ -53,8 +57,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i return (*Portal)(portal).checkConfusableName(ctx, userID, name) } -func (portal *PortalInternals) HandleMatrixEvent(sender *User, evt *event.Event) { - (*Portal)(portal).handleMatrixEvent(sender, evt) +func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) { + (*Portal)(portal).handleMatrixEvent(ctx, sender, evt) } func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) { @@ -109,8 +113,8 @@ func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleRemoteEvent(source *UserLogin, evt RemoteEvent) { - (*Portal)(portal).handleRemoteEvent(source, evt) +func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evt RemoteEvent) { + (*Portal)(portal).handleRemoteEvent(ctx, source, evt) } func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { @@ -297,6 +301,10 @@ func (portal *PortalInternals) DoThreadBackfill(ctx context.Context, source *Use (*Portal)(portal).doThreadBackfill(ctx, source, threadID) } +func (portal *PortalInternals) CutoffMessages(ctx context.Context, messages []*BackfillMessage, aggressiveDedup, forward bool, lastMessage *database.Message) []*BackfillMessage { + return (*Portal)(portal).cutoffMessages(ctx, messages, aggressiveDedup, forward, lastMessage) +} + func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { (*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread) } diff --git a/client.go b/client.go index dda51fbc..4dfbc6b8 100644 --- a/client.go +++ b/client.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "os" + "slices" "strconv" "strings" "sync/atomic" @@ -20,7 +21,9 @@ import ( "github.com/rs/zerolog" "github.com/tidwall/gjson" + "go.mau.fi/util/ptr" "go.mau.fi/util/retryafter" + "golang.org/x/exp/maps" "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/event" @@ -322,7 +325,9 @@ func (cli *Client) RequestStart(req *http.Request) { func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err error, handlerErr error, contentLength int, duration time.Duration) { var evt *zerolog.Event - if err != nil { + if errors.Is(err, context.Canceled) { + evt = zerolog.Ctx(req.Context()).Warn() + } else if err != nil { evt = zerolog.Ctx(req.Context()).Err(err) } else if handlerErr != nil { evt = zerolog.Ctx(req.Context()).Warn(). @@ -355,7 +360,9 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er if body := req.Context().Value(LogBodyContextKey); body != nil { evt.Interface("req_body", body) } - if err != nil { + if errors.Is(err, context.Canceled) { + evt.Msg("Request canceled") + } else if err != nil { evt.Msg("Request failed") } else if handlerErr != nil { evt.Msg("Request parsing failed") @@ -1498,21 +1505,19 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt Handler: parseRoomStateArray, }) if err == nil && cli.StateStore != nil { - clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID) - if clearErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching state") - } - for _, evts := range stateMap { + for evtType, evts := range stateMap { + if evtType == event.StateMember { + continue + } for _, evt := range evts { UpdateStateStore(ctx, cli.StateStore, evt) } } - clearErr = cli.StateStore.MarkMembersFetched(ctx, roomID) - if clearErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Msg("Failed to mark members as fetched after fetching full room state") + updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, maps.Values(stateMap[event.StateMember])) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(updateErr). + Stringer("room_id", roomID). + Msg("Failed to update members in state store after fetching members") } } return @@ -1864,24 +1869,26 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members") _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && cli.StateStore != nil { - clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin) - if clearErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching joined members") - } + fakeEvents := make([]*event.Event, len(resp.Joined)) + i := 0 for userID, member := range resp.Joined { - updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{ - Membership: event.MembershipJoin, - AvatarURL: id.ContentURIString(member.AvatarURL), - Displayname: member.DisplayName, - }) - if updateErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(updateErr). - Stringer("room_id", roomID). - Stringer("user_id", userID). - Msg("Failed to update membership in state store after fetching joined members") + fakeEvents[i] = &event.Event{ + StateKey: ptr.Ptr(userID.String()), + Type: event.StateMember, + RoomID: roomID, + Content: event.Content{Parsed: &event.MemberEventContent{ + Membership: event.MembershipJoin, + AvatarURL: id.ContentURIString(member.AvatarURL), + Displayname: member.DisplayName, + }}, } + i++ + } + updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, fakeEvents, event.MembershipJoin) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(updateErr). + Stringer("room_id", roomID). + Msg("Failed to update members in state store after fetching joined members") } } return @@ -1910,27 +1917,20 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb } } if err == nil && cli.StateStore != nil { - var clearMemberships []event.Membership + var onlyMemberships []event.Membership if extra.Membership != "" { - clearMemberships = append(clearMemberships, extra.Membership) - } - if extra.NotMembership == "" { - clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...) - if clearErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching joined members") - } - } - for _, evt := range resp.Chunk { - UpdateStateStore(ctx, cli.StateStore, evt) + onlyMemberships = []event.Membership{extra.Membership} + } else if extra.NotMembership != "" { + onlyMemberships = []event.Membership{event.MembershipJoin, event.MembershipLeave, event.MembershipInvite, event.MembershipBan, event.MembershipKnock} + onlyMemberships = slices.DeleteFunc(onlyMemberships, func(m event.Membership) bool { + return m == extra.NotMembership + }) } - if extra.NotMembership == "" && extra.Membership == "" { - markErr := cli.StateStore.MarkMembersFetched(ctx, roomID) - if markErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(markErr). - Msg("Failed to mark members as fetched after fetching full member list") - } + updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, resp.Chunk, onlyMemberships...) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(updateErr). + Stringer("room_id", roomID). + Msg("Failed to update members in state store after fetching members") } } return diff --git a/crypto/machine.go b/crypto/machine.go index 188aa210..85da2b3b 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -245,8 +245,22 @@ func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.Devic } } +func (mach *OlmMachine) otkCountIsForCrossSigningKey(otkCount *mautrix.OTKCount) bool { + if mach.crossSigningPubkeys == nil || otkCount.UserID != mach.Client.UserID { + return false + } + switch id.Ed25519(otkCount.DeviceID) { + case mach.crossSigningPubkeys.MasterKey, mach.crossSigningPubkeys.UserSigningKey, mach.crossSigningPubkeys.SelfSigningKey: + return true + } + return false +} + func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) { if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) { + if mach.otkCountIsForCrossSigningKey(otkCount) { + return + } // TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions mach.Log.Warn(). Str("target_user_id", otkCount.UserID.String()). diff --git a/event/state.go b/event/state.go index 0844936a..6e5f0ae4 100644 --- a/event/state.go +++ b/event/state.go @@ -187,6 +187,7 @@ type PolicyRecommendation string const ( PolicyRecommendationBan PolicyRecommendation = "m.ban" PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban" + PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban" ) // ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event. diff --git a/federation/resolution.go b/federation/resolution.go index e6785988..24085282 100644 --- a/federation/resolution.go +++ b/federation/resolution.go @@ -11,6 +11,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "net/http" "net/url" @@ -140,7 +141,7 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (* return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode) } var respData RespWellKnown - err = json.NewDecoder(resp.Body).Decode(&respData) + err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData) if err != nil { return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err) } else if respData.Server == "" { diff --git a/go.mod b/go.mod index e3700339..78d1b8c4 100644 --- a/go.mod +++ b/go.mod @@ -8,18 +8,18 @@ require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.22 - github.com/rs/xid v1.5.0 + github.com/mattn/go-sqlite3 v1.14.23 + github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.33.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943 + go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.26.0 - golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa + golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 golang.org/x/net v0.28.0 golang.org/x/sync v0.8.0 gopkg.in/yaml.v3 v3.0.1 @@ -35,7 +35,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.24.0 // indirect - golang.org/x/text v0.17.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.18.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 1584f6a8..0f1a0558 100644 --- a/go.sum +++ b/go.sum @@ -24,15 +24,16 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0= +github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw= github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= @@ -50,14 +51,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943 h1:wdJ9XC/M6lVUrwDltHPodaA3SRJq+S+AzGEXdQ/o2AQ= -go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4= +go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2 h1:VZQlKBbeJ7KOlYSh6BnN5uWQTY/ypn/bJv0YyEd+pXc= +go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2/go.mod h1:WgYvbt9rVmoFeajP97NunQU7AjgvTPiNExN3oTHeePs= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= -golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI= -golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA= +golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= @@ -66,10 +67,10 @@ golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= -golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go index cefe76d3..1779afa5 100644 --- a/hicli/database/statestore.go +++ b/hicli/database/statestore.go @@ -174,3 +174,7 @@ func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Roo } func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {} + +func (c *ClientStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { + return nil +} diff --git a/requests.go b/requests.go index c49f7c9c..189e620d 100644 --- a/requests.go +++ b/requests.go @@ -83,6 +83,7 @@ type ReqLogin struct { Token string `json:"token,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` + RefreshToken bool `json:"refresh_token,omitempty"` // Whether or not the returned credentials should be stored in the Client StoreCredentials bool `json:"-"` diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 2cfd1b97..d594c307 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -19,6 +19,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/confusable" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exslices" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -194,21 +195,37 @@ func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, return err } +const insertUserProfileQuery = ` + INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (room_id, user_id) DO UPDATE + SET membership=excluded.membership, + displayname=excluded.displayname, + avatar_url=excluded.avatar_url, + name_skeleton=excluded.name_skeleton +` + +type userProfileRow struct { + UserID id.UserID + Membership event.Membership + Displayname string + AvatarURL id.ContentURIString + NameSkeleton []byte +} + +func (u *userProfileRow) GetMassInsertValues() [5]any { + return [5]any{u.UserID, u.Membership, u.Displayname, u.AvatarURL, u.NameSkeleton} +} + +var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)") + func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { var nameSkeleton []byte if !store.DisableNameDisambiguation && len(member.Displayname) > 0 { nameSkeletonArr := confusable.SkeletonHash(member.Displayname) nameSkeleton = nameSkeletonArr[:] } - _, err := store.Exec(ctx, ` - INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (room_id, user_id) DO UPDATE - SET membership=excluded.membership, - displayname=excluded.displayname, - avatar_url=excluded.avatar_url, - name_skeleton=excluded.name_skeleton - `, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton) + _, err := store.Exec(ctx, insertUserProfileQuery, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton) return err } @@ -221,6 +238,50 @@ func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.Room return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } +const userProfileMassInsertBatchSize = 500 + +func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { + return store.DoTxn(ctx, nil, func(ctx context.Context) error { + err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...) + if err != nil { + return fmt.Errorf("failed to clear cached members: %w", err) + } + rows := make([]*userProfileRow, min(len(evts), userProfileMassInsertBatchSize)) + for _, evtsChunk := range exslices.Chunk(evts, userProfileMassInsertBatchSize) { + rows = rows[:0] + for _, evt := range evtsChunk { + content, ok := evt.Content.Parsed.(*event.MemberEventContent) + if !ok { + continue + } + row := &userProfileRow{ + UserID: id.UserID(*evt.StateKey), + Membership: content.Membership, + Displayname: content.Displayname, + AvatarURL: content.AvatarURL, + } + if !store.DisableNameDisambiguation && len(content.Displayname) > 0 { + nameSkeletonArr := confusable.SkeletonHash(content.Displayname) + row.NameSkeleton = nameSkeletonArr[:] + } + rows = append(rows, row) + } + query, args := userProfileMassInserter.Build([1]any{roomID}, rows) + _, err = store.Exec(ctx, query, args...) + if err != nil { + return fmt.Errorf("failed to insert members: %w", err) + } + } + if len(onlyMemberships) == 0 { + err = store.MarkMembersFetched(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to mark members as fetched: %w", err) + } + } + return nil + }) +} + func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { query := "DELETE FROM mx_user_profile WHERE room_id=$1" params := make([]any, len(memberships)+1) diff --git a/statestore.go b/statestore.go index 5f210e4f..e728b885 100644 --- a/statestore.go +++ b/statestore.go @@ -29,6 +29,7 @@ type StateStore interface { SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error + ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) @@ -270,9 +271,20 @@ func (store *MemoryStateStore) MarkMembersFetched(ctx context.Context, roomID id return nil } +func (store *MemoryStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { + _ = store.ClearCachedMembers(ctx, roomID, onlyMemberships...) + for _, evt := range evts { + UpdateStateStore(ctx, store, evt) + } + if len(onlyMemberships) == 0 { + _ = store.MarkMembersFetched(ctx, roomID) + } + return nil +} + func (store *MemoryStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { - store.membersLock.Lock() - defer store.membersLock.Unlock() + store.membersLock.RLock() + defer store.membersLock.RUnlock() return maps.Clone(store.Members[roomID]), nil }