diff --git a/state/accumulator.go b/state/accumulator.go index acab4555..f1bec0fd 100644 --- a/state/accumulator.go +++ b/state/accumulator.go @@ -347,54 +347,77 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string, // We can track this by loading the current snapshot ID (after snapshot) then rolling forward // the timeline until we hit a state event, at which point we make a new snapshot but critically // do NOT assign the new state event in the snapshot so as to represent the state before the event. - snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID) + currentSnapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID) + if err != nil { + return 0, nil, err + } + currentState, err := a.fetchInitialSnapshot(txn, currentSnapID) + if err != nil { + return 0, nil, err + } + + // newEvents can occasionally be large. In this situation it is worth batching the + // updates and insert statements to avoid the RTT overhead of multiple queries + // stacking up. So: work out what we want to do first. + numState := countStateEvents(newEvents) + eventsTableUpdates := make([]beforeSnapshotUpdate, 0, len(newEvents)) + snapshotInserts := make([]SnapshotRow, 0, numState) + // We'll need to assign snapshot IDs ourselves, so ask the database to give us some. + snapshotIDs, err := a.snapshotTable.ReserveSnapshotIDs(txn, numState) if err != nil { return 0, nil, err } for _, ev := range newEvents { - var replacesNID int64 // the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not, // as this is the before snapshot ID. - beforeSnapID := snapID - + beforeSnapID := currentSnapID + var replacesNID int64 if ev.IsState { - // make a new snapshot and update the snapshot ID - var oldStripped StrippedEvents - if snapID != 0 { - oldStripped, err = a.strippedEventsForSnapshot(txn, snapID) - if err != nil { - return 0, nil, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err) - } - } - newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev) + // make a new snapshot + stateAfterEvent, replacedNID, err := a.calculateNewSnapshot(currentState, ev) if err != nil { return 0, nil, fmt.Errorf("failed to calculateNewSnapshot: %s", err) } replacesNID = replacedNID - memNIDs, otherNIDs := newStripped.NIDs() - newSnapshot := &SnapshotRow{ + // claim one of our reserved snap IDs for this snapshot + afterSnapID := snapshotIDs[0] + snapshotIDs = snapshotIDs[1:] + // decide to insert a new snapshot row. + memNIDs, otherNIDs := stateAfterEvent.NIDs() + snapshotInserts = append(snapshotInserts, SnapshotRow{ + SnapshotID: afterSnapID, RoomID: roomID, MembershipEvents: memNIDs, OtherEvents: otherNIDs, - } - if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil { - return 0, nil, fmt.Errorf("failed to insert new snapshot: %w", err) - } - snapID = newSnapshot.SnapshotID - } - if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil { - return 0, nil, err + }) + // This snapshot is now the current state of the room. + currentState = stateAfterEvent + currentSnapID = afterSnapID } + eventsTableUpdates = append(eventsTableUpdates, beforeSnapshotUpdate{ + NID: ev.NID, + BeforeStateSnapshotID: beforeSnapID, + ReplacesNID: replacesNID, + }) } + // Now that we've worked out what we want to do, let's actually do it. + if len(snapshotInserts) > 0 { + if err = a.snapshotTable.BulkInsert(txn, snapshotInserts); err != nil { + return 0, nil, fmt.Errorf("failed to insert new snapshot: %w", err) + } + } + if err = a.eventsTable.UpdateBeforeSnapshotIDs(txn, eventsTableUpdates); err != nil { + return 0, nil, fmt.Errorf("failed to update beforeSnapshotIDs: %w", err) + } if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil { return 0, nil, fmt.Errorf("HandleSpaceUpdates: %s", err) } // the last fetched snapshot ID is the current one info := a.roomInfoDelta(roomID, newEvents) - if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil { - return 0, nil, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err) + if err = a.roomsTable.Upsert(txn, info, currentSnapID, latestNID); err != nil { + return 0, nil, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", currentSnapID, err) } return numNew, timelineNIDs, nil } @@ -484,6 +507,26 @@ func (a *Accumulator) filterAndParseTimelineEvents(txn *sqlx.Tx, roomID string, return dedupedEvents[seenIndex+1:], nil } +func countStateEvents(events []Event) (count int64) { + for _, event := range events { + if event.IsState { + count++ + } + } + return count +} + +func (a *Accumulator) fetchInitialSnapshot(txn *sqlx.Tx, snapID int64) (StrippedEvents, error) { + if snapID == 0 { + return StrippedEvents{}, nil + } + stripped, err := a.strippedEventsForSnapshot(txn, snapID) + if err != nil { + return nil, fmt.Errorf("failed to load stripped state events for snapshot %d: %w", snapID, err) + } + return stripped, nil +} + // Delta returns a list of events of at most `limit` for the room not including `lastEventNID`. // Returns the latest NID of the last event (most recent) func (a *Accumulator) Delta(roomID string, lastEventNID int64, limit int) (eventsJSON []json.RawMessage, latest int64, err error) {