From d6ebcdef433eca5b0883843b0e5fd7d655837192 Mon Sep 17 00:00:00 2001
From: Kegan Dougal
Date: Fri, 11 Aug 2023 15:39:44 +0100
Subject: [PATCH] Re-write state accumulator logic

As per #211 - this combines Initialise and Accumulate into a single
ProcessRoomEvents function which can do snapshots / timelines.
Implements the "brand new snapshot on create event" logic, which
currently does not correctly invalidate caches. E2E tests pass,
but integ tests are broken.
---
 state/accumulator.go      | 277 ++++++++++++++++++++++++++++++++++++++
 sync2/handler2/handler.go | 102 ++++++++++++++
 sync2/poller.go           |  52 +++----
 sync2/poller_test.go      |   3 +
 4 files changed, 402 insertions(+), 32 deletions(-)

diff --git a/state/accumulator.go b/state/accumulator.go
index 1eb86c3b..96362b36 100644
--- a/state/accumulator.go
+++ b/state/accumulator.go
@@ -132,6 +132,283 @@ func (a *Accumulator) roomInfoDelta(roomID string, events []Event) RoomInfo {
 	}
 }
 
+type ProcessRoomEventResult struct {
+	BrandNewSnapshot bool    // true if calculateSnapshotForRoom built a snapshot from the `state` section
+	SnapshotID       int64   // the room's current snapshot after processing; 0 if nothing was processed
+	TimelineNIDs     []int64 // NIDs of the timeline events which were newly inserted by this call
+}
+
+// ProcessRoomEvents creates state snapshots and inserts timeline sections into the DB. Returns:
+// - The latest state snapshot NID
+// - The NIDs of the new timeline events
+// - Whether a brand new snapshot was created
+// Downstream components need to refetch room state based on the state snapshot NID if a brand new snapshot was created.
+// Otherwise, it can just consume the timeline events and roll forward state.
+func (a *Accumulator) ProcessRoomEvents(userID, roomID string, timeline, state []json.RawMessage, prevBatch string) (*ProcessRoomEventResult, error) {
+	result := &ProcessRoomEventResult{}
+	// If there are no timeline events, do nothing. Note: even if there _are_ state events, such that we would
+	// generate a new state snapshot, we won't then subsequently _use_ this snapshot ID for any timeline event,
+	// meaning we'd end up with a dangling snapshot. This is why we don't care how long the state slice is.
+	// This is a weird case though so we will return an error/sentry it.
+	if len(timeline) == 0 {
+		if len(state) > 0 {
+			err := fmt.Errorf("processRoomEvents: %v received 0 timeline and %v state events for room %v which doesn't make sense, ignoring", userID, len(state), roomID)
+			sentry.CaptureException(err)
+			return result, err
+		}
+		return result, nil
+	}
+
+	err := sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error {
+		// The first stage of accumulating events is mostly around validation around what the upstream HS sends us. For accumulation to work correctly
+		// we expect:
+		// - there to be no duplicate events
+		// - if there are new events, they are always new.
+		// Both of these assumptions can be false for different reasons
+		dedupedEvents, err := a.filterAndParseTimelineEvents(txn, roomID, timeline, prevBatch)
+		if err != nil {
+			return fmt.Errorf("filterAndParseTimelineEvents: %w", err)
+		}
+		if len(dedupedEvents) == 0 {
+			return nil // nothing to do
+		}
+
+		// we know there are some events to insert which are new. We need to work out what snapshot
+		// to base these changes on.
+
+		snapID, brandNew, err := a.calculateSnapshotForRoom(txn, roomID, state)
+		if err != nil {
+			return err
+		}
+		result.BrandNewSnapshot = brandNew
+
+		// if we have just got a leave event for the polling user, and there is no snapshot for this room already, then
+		// we do NOT want to add this event to the events table, nor do we want to make a room snapshot. This is because
+		// this leave event is an invite rejection, rather than a normal event. Invite rejections cannot be processed in
+		// a normal way because we lack room state (no create event, PLs, etc). If we were to process the invite rejection,
+		// the room state would just be a single event: this leave event, which is wrong.
+		if len(dedupedEvents) == 1 &&
+			dedupedEvents[0].Type == "m.room.member" &&
+			(dedupedEvents[0].Membership == "leave" || dedupedEvents[0].Membership == "_leave") &&
+			dedupedEvents[0].StateKey == userID &&
+			snapID == 0 {
+			logger.Info().Str("event_id", dedupedEvents[0].ID).Str("room_id", roomID).Str("user_id", userID).Msg(
+				"Accumulator: skipping processing of leave event, as no snapshot exists",
+			)
+			return nil
+		}
+
+		// Insert the event JSON
+		eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
+		if err != nil {
+			return fmt.Errorf("failed to insert events: %w", err)
+		}
+		if len(eventIDToNID) == 0 {
+			// nothing to do, we already know about these events
+			return nil
+		}
+
+		// Set the NID values for each event now they are inserted
+		var timelineNIDs []int64
+		var latestNID int64
+		newEvents := make([]Event, 0, len(eventIDToNID))
+		for _, ev := range dedupedEvents {
+			nid, ok := eventIDToNID[ev.ID]
+			if !ok {
+				continue // we must've seen this event which is why it's not in the map
+			}
+			ev.NID = int64(nid)
+			if gjson.GetBytes(ev.JSON, "state_key").Exists() {
+				// XXX: reusing this to mean "it's a state event" as well as "it's part of the state v2 response"
+				// its important that we don't insert 'ev' at this point as this should be False in the DB.
+				ev.IsState = true
+			}
+			// assign the highest nid value to the latest nid.
+			// we'll return this to the caller so they can stay in-sync
+			if ev.NID > latestNID {
+				latestNID = ev.NID
+			}
+			newEvents = append(newEvents, ev)
+			timelineNIDs = append(timelineNIDs, ev.NID)
+		}
+
+		// Begin rolling forward state and calculating snapshots for each event
+		for _, ev := range newEvents {
+			var replacesNID int64
+			// the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not,
+			// as this is the before snapshot ID.
+			beforeSnapID := snapID
+
+			if ev.IsState {
+				// make a new snapshot and update the snapshot ID
+				var oldStripped StrippedEvents
+				if snapID != 0 {
+					oldStripped, err = a.strippedEventsForSnapshot(txn, snapID)
+					if err != nil {
+						return fmt.Errorf("failed to load stripped state events for snapshot %d: %w", snapID, err)
+					}
+				}
+				newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev)
+				if err != nil {
+					return fmt.Errorf("failed to calculateNewSnapshot: %w", err)
+				}
+				replacesNID = replacedNID
+				memNIDs, otherNIDs := newStripped.NIDs()
+				newSnapshot := &SnapshotRow{
+					RoomID:           roomID,
+					MembershipEvents: memNIDs,
+					OtherEvents:      otherNIDs,
+				}
+				if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil {
+					return fmt.Errorf("failed to insert new snapshot: %w", err)
+				}
+				snapID = newSnapshot.SnapshotID
+			}
+			if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil {
+				return err
+			}
+		}
+
+		if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil {
+			return fmt.Errorf("HandleSpaceUpdates: %w", err)
+		}
+
+		// the last fetched snapshot ID is the current one, so set it on the rooms table.
+		info := a.roomInfoDelta(roomID, newEvents)
+		if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil {
+			return fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
+		}
+		result.SnapshotID = snapID
+		result.TimelineNIDs = timelineNIDs
+		return nil
+	})
+	return result, err
+}
+
+// Calculate a snapshot for this room. There are 3 cases to consider.
+// - 1: If there are no state events, this just returns the latest snapshot for this room, which may be 0 if there are no snapshots.
+// If there are state events and there is:
+// - 2: no create event in this slice: the state events provided are combined/rolled forward with the latest snapshot.
+// - 3: a create event in this slice: the state slice alone is used to calculate the latest snapshot.
+//
+// These cases can occur for different real-world reasons:
+//
+// 1: Client is up-to-date and receives a single live message. No state events in `state`. Snapshot ID exists.
+// 1: Client just joined a brand new room and all the room creation events fit inside the `timeline` section, so no state events in `state`. Snapshot ID doesn't exist.
+// 2: Client poller is super slow, or we restarted a previously stopped poller. The classic 'gappy state' use case.
+// 3: Client joins (or re-joins) a room. The state slice here is the most accurate representation of the room, as it will be the
+// latest upstream state (which will have fixed state resets), vs rolling forward which will not.
+func (a *Accumulator) calculateSnapshotForRoom(txn *sqlx.Tx, roomID string, state []json.RawMessage) (snapID int64, brandNew bool, err error) {
+	snapshotID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
+	if err != nil {
+		return 0, false, fmt.Errorf("calculateSnapshotForRoom: error fetching snapshot id for room %s: %w", roomID, err)
+	}
+	if len(state) == 0 {
+		return snapshotID, false, nil // Case 1, the snapshot may be 0.
+	}
+
+	// Preprocess the state JSON
+	events := make([]Event, len(state))
+	for i := range events {
+		events[i] = Event{
+			JSON:    state[i],
+			RoomID:  roomID,
+			IsState: true,
+		}
+	}
+	events = filterAndEnsureFieldsSet(events)
+	if len(events) == 0 {
+		return 0, false, fmt.Errorf("calculateSnapshotForRoom: failed to insert events, all events were filtered out")
+	}
+
+	// Check for a create event
+	hasCreateEvent := false
+	allEventIDs := make([]string, 0, len(events))
+	for _, ev := range events {
+		allEventIDs = append(allEventIDs, ev.ID)
+		if ev.Type == "m.room.create" && ev.StateKey == "" {
+			hasCreateEvent = true
+		}
+	}
+
+	// Insert new events
+	insertedEventIDToNID, err := a.eventsTable.Insert(txn, events, false)
+	if err != nil {
+		return 0, false, fmt.Errorf("calculateSnapshotForRoom: failed to insert events: %w", err)
+	}
+
+	if hasCreateEvent { // Case 3
+		// If we have a create event then we want to ignore snapshot ID and use all the events in `state.`
+		// So just use allEventIDs, and verify all those events exist in the DB now.
+		strippedEvents, err := a.eventsTable.SelectStrippedEventsByIDs(txn, true, allEventIDs)
+		if err != nil {
+			return 0, false, fmt.Errorf("calculateSnapshotForRoom.SelectStrippedEventsByIDs: %w", err)
+		}
+		memNIDs, otherNIDs := strippedEvents.NIDs()
+		// Make a current snapshot
+		snapshot := &SnapshotRow{
+			RoomID:           roomID,
+			MembershipEvents: pq.Int64Array(memNIDs),
+			OtherEvents:      pq.Int64Array(otherNIDs),
+		}
+		err = a.snapshotTable.Insert(txn, snapshot)
+		if err != nil {
+			return 0, false, fmt.Errorf("calculateSnapshotForRoom: failed to insert create snapshot: %w", err)
+		}
+		return snapshot.SnapshotID, true, nil
+	}
+	// no create event: Case 2
+	// If we don't have a create event then we want to load the state at snapshot ID (which may be the empty set)
+	// and roll it forward with all the events in `state`.
+
+	// load the old snapshot. This must be snapshotID: the named return value snapID is still 0 at this point.
+	var oldStripped StrippedEvents
+	if snapshotID != 0 {
+		oldStripped, err = a.strippedEventsForSnapshot(txn, snapshotID)
+		if err != nil {
+			return 0, false, fmt.Errorf("calculateSnapshotForRoom: failed to load stripped state events for snapshot %d: %w", snapshotID, err)
+		}
+	}
+
+	genKey := func(ev Event) string {
+		return fmt.Sprintf("%s\x1f%s", ev.Type, ev.StateKey)
+	}
+
+	insertedEventsByTypeStateKey := make(map[string]Event)
+	for _, ev := range events {
+		if nid, ok := insertedEventIDToNID[ev.ID]; ok {
+			ev.NID = nid
+			insertedEventsByTypeStateKey[genKey(ev)] = ev
+		}
+	}
+
+	// we need to combine oldStripped and insertedEventsByTypeStateKey, preferring entries in insertedEventsByTypeStateKey when there is a tuple clash
+	newStripped := make(StrippedEvents, 0, len(oldStripped)+len(insertedEventsByTypeStateKey))
+	for _, ev := range oldStripped {
+		if _, replaced := insertedEventsByTypeStateKey[genKey(ev)]; !replaced {
+			newStripped = append(newStripped, ev) // tuple untouched by `state`: keep the old event
+		}
+	}
+	for _, ev := range events { // then the inserted events: replacements for old tuples AND tuples brand new to the room
+		key := genKey(ev)
+		if newEvent, ok := insertedEventsByTypeStateKey[key]; ok {
+			newStripped = append(newStripped, newEvent)
+			delete(insertedEventsByTypeStateKey, key) // emit each tuple exactly once
+		}
+	}
+
+	memNIDs, otherNIDs := newStripped.NIDs()
+	newSnapshot := &SnapshotRow{
+		RoomID:           roomID,
+		MembershipEvents: memNIDs,
+		OtherEvents:      otherNIDs,
+	}
+	if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil {
+		return 0, false, fmt.Errorf("calculateSnapshotForRoom: failed to insert new snapshot: %w", err)
+	}
+	return newSnapshot.SnapshotID, true, nil
+}
+
 type InitialiseResult struct {
 	// AddedEvents is true iff this call to Initialise added new state events to the DB.
AddedEvents bool
diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go
index 102ca600..e91155dc 100644
--- a/sync2/handler2/handler.go
+++ b/sync2/handler2/handler.go
@@ -350,6 +350,108 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
 	}
 }
 
+// ProcessNewEvents persists any transaction IDs found in the timeline, hands the timeline and state sections
+// to the accumulator, then publishes the resulting V2Initialise / V2Accumulate / V2TransactionID payloads.
+func (h *Handler) ProcessNewEvents(ctx context.Context, userID, deviceID, roomID string, timeline, state []json.RawMessage, prevBatch string) {
+	// Remember any transaction IDs that may be unique to this user
+	eventIDsWithTxns := make([]string, 0, len(timeline))     // in timeline order
+	eventIDToTxnID := make(map[string]string, len(timeline)) // event_id -> txn_id
+	// Also remember events which were sent by this user but lack a transaction ID.
+	eventIDsLackingTxns := make([]string, 0, len(timeline))
+	for _, e := range timeline {
+		parsed := gjson.ParseBytes(e)
+		eventID := parsed.Get("event_id").Str
+
+		if txnID := parsed.Get("unsigned.transaction_id"); txnID.Exists() {
+			eventIDsWithTxns = append(eventIDsWithTxns, eventID)
+			eventIDToTxnID[eventID] = txnID.Str
+			continue
+		}
+
+		if sender := parsed.Get("sender"); sender.Str == userID {
+			eventIDsLackingTxns = append(eventIDsLackingTxns, eventID)
+		}
+	}
+
+	if len(eventIDToTxnID) > 0 {
+		// persist the txn IDs
+		err := h.Store.TransactionsTable.Insert(userID, deviceID, eventIDToTxnID)
+		if err != nil {
+			logger.Err(err).Str("user", userID).Str("device", deviceID).Int("num_txns", len(eventIDToTxnID)).Msg("failed to persist txn IDs for user")
+			internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
+		}
+	}
+
+	result, err := h.Store.Accumulator.ProcessRoomEvents(userID, roomID, timeline, state, prevBatch)
+	if err != nil {
+		logger.Err(err).Str("room", roomID).Msg("V2: failed to process new room events")
+		return
+	}
+	if result.BrandNewSnapshot && result.SnapshotID > 0 {
+		h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Initialise{
+			RoomID:      roomID,
+			SnapshotNID: result.SnapshotID,
+		})
+	}
+	if len(result.TimelineNIDs) > 0 {
+		h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
+			RoomID:    roomID,
+			PrevBatch: prevBatch,
+			EventNIDs: result.TimelineNIDs,
+		})
+	}
+
+	if len(eventIDToTxnID) > 0 || len(eventIDsLackingTxns) > 0 {
+		// The call to ProcessRoomEvents above only tells us about new events' NIDs;
+		// for existing events we need to requery the database to fetch them.
+		// Rather than try to reuse work, keep things simple and just fetch NIDs for
+		// all events with txnIDs, plus our own events which lack one.
+		var nidsByIDs map[string]int64
+		eventIDsToFetch := append(eventIDsWithTxns[:len(eventIDsWithTxns):len(eventIDsWithTxns)], eventIDsLackingTxns...) // capped so append copies
+		err = sqlutil.WithTransaction(h.Store.DB, func(txn *sqlx.Tx) error {
+			nidsByIDs, err = h.Store.EventsTable.SelectNIDsByIDs(txn, eventIDsToFetch)
+			return err
+		})
+		if err != nil {
+			logger.Err(err).
+				Int("timeline", len(timeline)).
+				Int("num_transaction_ids", len(eventIDsWithTxns)).
+				Int("num_missing_transaction_ids", len(eventIDsLackingTxns)).
+				Str("room", roomID).
+				Msg("V2: failed to fetch nids for event transaction_id handling")
+			internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
+			return
+		}
+
+		for eventID, nid := range nidsByIDs {
+			txnID, ok := eventIDToTxnID[eventID]
+			if ok {
+				h.PendingTxnIDs.SeenTxnID(eventID)
+				h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{
+					EventID:       eventID,
+					RoomID:        roomID,
+					UserID:        userID,
+					DeviceID:      deviceID,
+					TransactionID: txnID,
+					NID:           nid,
+				})
+			} else {
+				allClear, _ := h.PendingTxnIDs.MissingTxnID(eventID, userID, deviceID)
+				if allClear {
+					h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{
+						EventID:       eventID,
+						RoomID:        roomID,
+						UserID:        userID,
+						DeviceID:      deviceID,
+						TransactionID: "",
+						NID:           nid,
+					})
+				}
+			}
+		}
+	}
+}
+
 func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.RawMessage) []json.RawMessage {
 	res, err := h.Store.Initialise(roomID, state)
 	if err != nil {
diff --git a/sync2/poller.go b/sync2/poller.go
index d5791a66..8a3a8f2a 100644
--- 
a/sync2/poller.go +++ b/sync2/poller.go @@ -38,6 +38,7 @@ type V2DataReceiver interface { // Initialise the room, if it hasn't been already. This means the state section of the v2 response. // If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB. Initialise(ctx context.Context, roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID? + ProcessNewEvents(ctx context.Context, userID, deviceID, roomID string, timeline, state []json.RawMessage, prevBatch string) // SetTyping indicates which users are typing. SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) // Sent when there is a new receipt @@ -297,6 +298,15 @@ func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json. wg.Wait() return } +func (h *PollerMap) ProcessNewEvents(ctx context.Context, userID, deviceID, roomID string, timeline, state []json.RawMessage, prevBatch string) { + var wg sync.WaitGroup + wg.Add(1) + h.executor <- func() { + h.callbacks.ProcessNewEvents(ctx, userID, deviceID, roomID, timeline, state, prevBatch) + wg.Done() + } + wg.Wait() +} func (h *PollerMap) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) { var wg sync.WaitGroup wg.Add(1) @@ -675,34 +685,17 @@ func (p *poller) parseGlobalAccountData(ctx context.Context, res *SyncResponse) func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) { ctx, task := internal.StartTask(ctx, "parseRoomsResponse") defer task.End() - stateCalls := 0 - timelineCalls := 0 typingCalls := 0 receiptCalls := 0 + stateEvents := 0 + timelineEvents := 0 for roomID, roomData := range res.Rooms.Join { - if len(roomData.State.Events) > 0 { - stateCalls++ - prependStateEvents := p.receiver.Initialise(ctx, roomID, roomData.State.Events) - if len(prependStateEvents) > 0 { - // The poller has just learned of these state events due to an - // incremental poller sync; we must have 
missed the opportunity to see - // these down /sync in a timeline. As a workaround, inject these into - // the timeline now so that future events are received under the - // correct room state. - const warnMsg = "parseRoomsResponse: prepending state events to timeline after gappy poll" - logger.Warn().Str("room_id", roomID).Int("prependStateEvents", len(prependStateEvents)).Msg(warnMsg) - hub := internal.GetSentryHubFromContextOrDefault(ctx) - hub.WithScope(func(scope *sentry.Scope) { - scope.SetContext(internal.SentryCtxKey, map[string]interface{}{ - "room_id": roomID, - "num_prepend_state_events": len(prependStateEvents), - }) - hub.CaptureMessage(warnMsg) - }) - p.trackGappyStateSize(len(prependStateEvents)) - roomData.Timeline.Events = append(prependStateEvents, roomData.Timeline.Events...) - } + stateEvents += len(roomData.State.Events) + timelineEvents += len(roomData.Timeline.Events) + if len(roomData.Timeline.Events) > 0 { + p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited) } + p.receiver.ProcessNewEvents(ctx, p.userID, p.deviceID, roomID, roomData.Timeline.Events, roomData.State.Events, roomData.Timeline.PrevBatch) // process typing/receipts before events so we seed the caches correctly for when we return the room for _, ephEvent := range roomData.Ephemeral.Events { ephEventType := gjson.GetBytes(ephEvent, "type").Str @@ -720,11 +713,6 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) { if len(roomData.AccountData.Events) > 0 { p.receiver.OnAccountData(ctx, p.userID, roomID, roomData.AccountData.Events) } - if len(roomData.Timeline.Events) > 0 { - timelineCalls++ - p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited) - p.receiver.Accumulate(ctx, p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events) - } // process unread counts AFTER events so global caches have been updated by the time this metadata is added. 
// Previously we did this BEFORE events so we atomically showed the event and the unread count in one go, but @@ -736,8 +724,8 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) { for roomID, roomData := range res.Rooms.Leave { if len(roomData.Timeline.Events) > 0 { p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited) - p.receiver.Accumulate(ctx, p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events) } + p.receiver.ProcessNewEvents(ctx, p.userID, p.deviceID, roomID, roomData.Timeline.Events, roomData.State.Events, roomData.Timeline.PrevBatch) // Pass the leave event directly to OnLeftRoom. We need to do this _in addition_ to calling Accumulate to handle // the case where a user rejects an invite (there will be no room state, but the user still expects to see the leave event). var leaveEvent json.RawMessage @@ -757,8 +745,8 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) { } p.totalReceipts += receiptCalls - p.totalStateCalls += stateCalls - p.totalTimelineCalls += timelineCalls + p.totalStateCalls += stateEvents + p.totalTimelineCalls += timelineEvents p.totalTyping += typingCalls p.totalInvites += len(res.Rooms.Invite) } diff --git a/sync2/poller_test.go b/sync2/poller_test.go index 9094e785..8de0f36c 100644 --- a/sync2/poller_test.go +++ b/sync2/poller_test.go @@ -623,6 +623,9 @@ func (s *mockDataReceiver) OnE2EEData(ctx context.Context, userID, deviceID stri } func (s *mockDataReceiver) OnTerminated(ctx context.Context, pollerID PollerID) {} func (s *mockDataReceiver) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) { +} +func (s *mockDataReceiver) ProcessNewEvents(ctx context.Context, userID, deviceID, roomID string, timeline, state []json.RawMessage, prevBatch string) { + } func newMocks(doSyncV2 func(authHeader, since string) (*SyncResponse, int, error)) (*mockDataReceiver, *mockClient) {