Skip to content

Commit

Permalink
Re-write state accumulator logic
Browse files Browse the repository at this point in the history
As per #211 - this combines Initialise and Accumulate into a single
ProcessRoomEvents function which can do snapshots / timelines.

Implements the "brand new snapshot on create event" logic, which
currently does not correctly invalidate caches.

E2E tests pass, but integ tests are broken.
  • Loading branch information
kegsay committed Aug 11, 2023
1 parent f73c8e4 commit d6ebcde
Show file tree
Hide file tree
Showing 4 changed files with 402 additions and 32 deletions.
277 changes: 277 additions & 0 deletions state/accumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,283 @@ func (a *Accumulator) roomInfoDelta(roomID string, events []Event) RoomInfo {
}
}

// ProcessRoomEventResult is the outcome of a call to Accumulator.ProcessRoomEvents.
type ProcessRoomEventResult struct {
	// BrandNewSnapshot is the flag returned by calculateSnapshotForRoom: true
	// when a snapshot was built from the supplied state events rather than
	// reusing the room's current snapshot.
	BrandNewSnapshot bool
	// SnapshotID is the snapshot ID stored on the room once all new timeline
	// events have been processed. It is left at 0 when nothing was inserted.
	SnapshotID int64
	// TimelineNIDs holds the NIDs of the timeline events newly inserted by
	// this call, in timeline order.
	TimelineNIDs []int64
}

// ProcessRoomEvents creates state snapshots and inserts timeline sections into the DB.
// All work happens inside a single transaction. The returned result contains:
//   - the latest state snapshot ID (SnapshotID),
//   - the NIDs of the newly inserted timeline events (TimelineNIDs),
//   - whether a brand new snapshot was created from `state` (BrandNewSnapshot).
//
// Downstream components need to refetch room state based on the state snapshot ID if a
// brand new snapshot was created. Otherwise, they can just consume the timeline events
// and roll forward state.
//
// An error is returned if the transaction fails, or if `timeline` is empty while `state`
// is not. The result pointer is always non-nil, but its fields are only meaningful when
// the returned error is nil.
func (a *Accumulator) ProcessRoomEvents(userID, roomID string, timeline, state []json.RawMessage, prevBatch string) (*ProcessRoomEventResult, error) {
	result := &ProcessRoomEventResult{}
	// If there are no timeline events, do nothing. Note: even if there _are_ state events, such that we would
	// generate a new state snapshot, we won't then subsequently _use_ this snapshot ID for any timeline event,
	// meaning we'd end up with a dangling snapshot. This is why we don't care how long the state slice is.
	// This is a weird case though so we will return an error/sentry it.
	if len(timeline) == 0 {
		if len(state) > 0 {
			// Argument order must match the verbs: user, number of state events, room.
			err := fmt.Errorf("processRoomEvents: %v received 0 timeline and %v state events for room %v which doesn't make sense, ignoring", userID, len(state), roomID)
			sentry.CaptureException(err)
			return result, err
		}
		return result, nil
	}

	err := sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error {
		// The first stage of accumulating events is mostly around validation around what the upstream HS sends us.
		// For accumulation to work correctly we expect:
		// - there to be no duplicate events
		// - if there are new events, they are always new.
		// Both of these assumptions can be false for different reasons.
		dedupedEvents, err := a.filterAndParseTimelineEvents(txn, roomID, timeline, prevBatch)
		if err != nil {
			return fmt.Errorf("filterAndParseTimelineEvents: %w", err)
		}
		if len(dedupedEvents) == 0 {
			return nil // nothing to do
		}

		// We know there are some events to insert which are new. We need to work out what snapshot
		// to base these changes on.
		snapID, brandNew, err := a.calculateSnapshotForRoom(txn, roomID, state)
		if err != nil {
			return err
		}
		result.BrandNewSnapshot = brandNew

		// If we have just got a leave event for the polling user, and there is no snapshot for this room already, then
		// we do NOT want to add this event to the events table, nor do we want to make a room snapshot. This is because
		// this leave event is an invite rejection, rather than a normal event. Invite rejections cannot be processed in
		// a normal way because we lack room state (no create event, PLs, etc). If we were to process the invite rejection,
		// the room state would just be a single event: this leave event, which is wrong.
		if len(dedupedEvents) == 1 &&
			dedupedEvents[0].Type == "m.room.member" &&
			(dedupedEvents[0].Membership == "leave" || dedupedEvents[0].Membership == "_leave") &&
			dedupedEvents[0].StateKey == userID &&
			snapID == 0 {
			// err is always nil at this point, so it is not attached to the log line.
			logger.Info().Str("event_id", dedupedEvents[0].ID).Str("room_id", roomID).Str("user_id", userID).Msg(
				"Accumulator: skipping processing of leave event, as no snapshot exists",
			)
			return nil
		}

		// Insert the event JSON
		eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
		if err != nil {
			return fmt.Errorf("failed to insert events: %w", err)
		}
		if len(eventIDToNID) == 0 {
			// nothing to do, we already know about these events
			return nil
		}

		// Set the NID values for each event now they are inserted
		timelineNIDs := make([]int64, 0, len(eventIDToNID))
		var latestNID int64
		newEvents := make([]Event, 0, len(eventIDToNID))
		for _, ev := range dedupedEvents {
			nid, ok := eventIDToNID[ev.ID]
			if !ok {
				continue // we must've seen this event which is why it's not in the map
			}
			ev.NID = int64(nid)
			if gjson.GetBytes(ev.JSON, "state_key").Exists() {
				// XXX: reusing this to mean "it's a state event" as well as "it's part of the state v2 response"
				// its important that we don't insert 'ev' at this point as this should be False in the DB.
				ev.IsState = true
			}
			// assign the highest nid value to the latest nid.
			// we'll return this to the caller so they can stay in-sync
			if ev.NID > latestNID {
				latestNID = ev.NID
			}
			newEvents = append(newEvents, ev)
			timelineNIDs = append(timelineNIDs, ev.NID)
		}

		// Begin rolling forward state and calculating snapshots for each event
		for _, ev := range newEvents {
			var replacesNID int64
			// the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not,
			// as this is the before snapshot ID.
			beforeSnapID := snapID

			if ev.IsState {
				// make a new snapshot and update the snapshot ID
				var oldStripped StrippedEvents
				if snapID != 0 {
					oldStripped, err = a.strippedEventsForSnapshot(txn, snapID)
					if err != nil {
						return fmt.Errorf("failed to load stripped state events for snapshot %d: %w", snapID, err)
					}
				}
				newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev)
				if err != nil {
					return fmt.Errorf("failed to calculateNewSnapshot: %w", err)
				}
				replacesNID = replacedNID
				memNIDs, otherNIDs := newStripped.NIDs()
				newSnapshot := &SnapshotRow{
					RoomID:           roomID,
					MembershipEvents: memNIDs,
					OtherEvents:      otherNIDs,
				}
				if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil {
					return fmt.Errorf("failed to insert new snapshot: %w", err)
				}
				snapID = newSnapshot.SnapshotID
			}
			if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil {
				return err
			}
		}

		if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil {
			return fmt.Errorf("HandleSpaceUpdates: %w", err)
		}

		// the last fetched snapshot ID is the current one, so set it on the rooms table.
		info := a.roomInfoDelta(roomID, newEvents)
		if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil {
			return fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
		}
		result.SnapshotID = snapID
		result.TimelineNIDs = timelineNIDs
		return nil
	})
	return result, err
}

// calculateSnapshotForRoom works out which snapshot new timeline events should be based on.
// There are 3 cases to consider.
//   - 1: If there are no state events, this just returns the latest snapshot for this room, which may be 0 if there are no snapshots.
//
// If there are state events and there is:
//   - 2: no create event in this slice: the state events provided are combined/rolled forward with the latest snapshot.
//   - 3: a create event in this slice: the state slice alone is used to calculate the latest snapshot.
//
// These cases can occur for different real-world reasons:
//
//	1: Client is up-to-date and receives a single live message. No state events in `state`. Snapshot ID exists.
//	1: Client just joined a brand new room and all the room creation events fit inside the `timeline` section, so no state events in `state`. Snapshot ID doesn't exist.
//	2: Client poller is super slow, or we restarted a previously stopped poller. The classic 'gappy state' use case.
//	3: Client joins (or re-joins) a room. The state slice here is the most accurate representation of the room, as it will be the
//	   latest upstream state (which will have fixed state resets), vs rolling forward which will not.
//
// It returns the snapshot ID to use, and brandNew=true whenever a snapshot row was inserted
// by this call (cases 2 and 3). In case 1 brandNew is false. State events are inserted into
// the events table as a side effect in cases 2 and 3.
func (a *Accumulator) calculateSnapshotForRoom(txn *sqlx.Tx, roomID string, state []json.RawMessage) (snapID int64, brandNew bool, err error) {
	// NB: keep the fetched ID in its own variable. The named result `snapID` is zero
	// until we return, so it must not be used to decide whether a snapshot exists.
	currentSnapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
	if err != nil {
		return 0, false, fmt.Errorf("calculateSnapshotForRoom: error fetching snapshot id for room %s: %w", roomID, err)
	}
	if len(state) == 0 {
		return currentSnapID, false, nil // Case 1, the snapshot may be 0.
	}

	// Preprocess the state JSON
	events := make([]Event, len(state))
	for i := range events {
		events[i] = Event{
			JSON:    state[i],
			RoomID:  roomID,
			IsState: true,
		}
	}
	events = filterAndEnsureFieldsSet(events)
	if len(events) == 0 {
		// There is no underlying error to wrap here: err is nil at this point.
		return 0, false, fmt.Errorf("calculateSnapshotForRoom: failed to insert events, all events were filtered out")
	}

	// Check for a create event
	hasCreateEvent := false
	allEventIDs := make([]string, 0, len(events))
	for _, ev := range events {
		allEventIDs = append(allEventIDs, ev.ID)
		if ev.Type == "m.room.create" && ev.StateKey == "" {
			hasCreateEvent = true
		}
	}

	// Insert new events
	insertedEventIDToNID, err := a.eventsTable.Insert(txn, events, false)
	if err != nil {
		return 0, false, fmt.Errorf("calculateSnapshotForRoom: failed to insert events: %w", err)
	}

	if hasCreateEvent { // Case 3
		// If we have a create event then we want to ignore snapshot ID and use all the events in `state.`
		// So just use allEventIDs, and verify all those events exist in the DB now.
		strippedEvents, err := a.eventsTable.SelectStrippedEventsByIDs(txn, true, allEventIDs)
		if err != nil {
			return 0, false, fmt.Errorf("calculateSnapshotForRoom.SelectStrippedEventsByIDs: %w", err)
		}
		memNIDs, otherNIDs := strippedEvents.NIDs()
		// Make a current snapshot
		snapshot := &SnapshotRow{
			RoomID:           roomID,
			MembershipEvents: pq.Int64Array(memNIDs),
			OtherEvents:      pq.Int64Array(otherNIDs),
		}
		err = a.snapshotTable.Insert(txn, snapshot)
		if err != nil {
			return 0, false, fmt.Errorf("calculateSnapshotForRoom: failed to insert create snapshot: %w", err)
		}
		return snapshot.SnapshotID, true, nil
	}

	// no create event: Case 2
	// If we don't have a create event then we want to load the state at snapshot ID (which may be the empty set)
	// and roll it forward with all the events in `state`.

	// load the old snapshot
	var oldStripped StrippedEvents
	if currentSnapID != 0 {
		oldStripped, err = a.strippedEventsForSnapshot(txn, currentSnapID)
		if err != nil {
			return 0, false, fmt.Errorf("calculateSnapshotForRoom: failed to load stripped state events for snapshot %d: %w", currentSnapID, err)
		}
	}

	genKey := func(ev Event) string {
		return fmt.Sprintf("%s\x1f%s", ev.Type, ev.StateKey)
	}

	// Only events which were newly inserted (i.e. have an entry in insertedEventIDToNID)
	// take part in the roll forward. If several share a (type, state_key), the last one wins.
	insertedEventsByTypeStateKey := make(map[string]Event, len(insertedEventIDToNID))
	for _, ev := range events {
		nid, ok := insertedEventIDToNID[ev.ID]
		if ok {
			ev.NID = nid
			insertedEventsByTypeStateKey[genKey(ev)] = ev
		}
	}

	// we need to combine oldStripped and insertedEventsByTypeStateKey, preferring entries in
	// insertedEventsByTypeStateKey when there is a tuple clash
	newStripped := make(StrippedEvents, 0, len(oldStripped)+len(insertedEventsByTypeStateKey))
	handledKeys := make(map[string]struct{}, len(oldStripped)+len(insertedEventsByTypeStateKey))
	for _, ev := range oldStripped {
		key := genKey(ev)
		handledKeys[key] = struct{}{}
		if newEvent, ok := insertedEventsByTypeStateKey[key]; ok {
			// use the newer event
			newStripped = append(newStripped, newEvent)
		} else {
			// use the old event
			newStripped = append(newStripped, ev)
		}
	}
	// Add inserted events whose (type, state_key) did not exist in the old snapshot.
	// Walk `events` rather than the map so the resulting order is deterministic.
	for _, ev := range events {
		key := genKey(ev)
		if _, done := handledKeys[key]; done {
			continue
		}
		newEvent, ok := insertedEventsByTypeStateKey[key]
		if !ok {
			continue // not newly inserted
		}
		handledKeys[key] = struct{}{}
		newStripped = append(newStripped, newEvent)
	}

	memNIDs, otherNIDs := newStripped.NIDs()
	newSnapshot := &SnapshotRow{
		RoomID:           roomID,
		MembershipEvents: memNIDs,
		OtherEvents:      otherNIDs,
	}
	if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil {
		return 0, false, fmt.Errorf("calculateSnapshotForRoom: failed to insert new snapshot: %w", err)
	}
	return newSnapshot.SnapshotID, true, nil
}

type InitialiseResult struct {
// AddedEvents is true iff this call to Initialise added new state events to the DB.
AddedEvents bool
Expand Down
102 changes: 102 additions & 0 deletions sync2/handler2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,108 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
}
}

// ProcessNewEvents handles a batch of timeline and state events for a room as polled on
// behalf of userID/deviceID. It:
//  1. records any unsigned.transaction_id values found on timeline events,
//  2. hands the events to Accumulator.ProcessRoomEvents,
//  3. publishes V2Initialise (when a brand new snapshot was created) and V2Accumulate
//     (when new timeline events were inserted) payloads,
//  4. publishes V2TransactionID payloads for events carrying a transaction ID, and for
//     events sent by userID which lack one once PendingTxnIDs reports all clear.
//
// Errors are logged rather than returned. A failure to persist transaction IDs does not
// stop processing; a failure in ProcessRoomEvents or when fetching NIDs does.
func (h *Handler) ProcessNewEvents(ctx context.Context, userID, deviceID, roomID string, timeline, state []json.RawMessage, prevBatch string) {
	// Remember any transaction IDs that may be unique to this user
	eventIDsWithTxns := make([]string, 0, len(timeline))     // in timeline order
	eventIDToTxnID := make(map[string]string, len(timeline)) // event_id -> txn_id
	// Also remember events which were sent by this user but lack a transaction ID.
	eventIDsLackingTxns := make([]string, 0, len(timeline))
	for _, e := range timeline {
		parsed := gjson.ParseBytes(e)
		eventID := parsed.Get("event_id").Str

		if txnID := parsed.Get("unsigned.transaction_id"); txnID.Exists() {
			eventIDsWithTxns = append(eventIDsWithTxns, eventID)
			eventIDToTxnID[eventID] = txnID.Str
			continue
		}

		if sender := parsed.Get("sender"); sender.Str == userID {
			eventIDsLackingTxns = append(eventIDsLackingTxns, eventID)
		}
	}

	if len(eventIDToTxnID) > 0 {
		// persist the txn IDs; this is best-effort, so carry on if it fails.
		err := h.Store.TransactionsTable.Insert(userID, deviceID, eventIDToTxnID)
		if err != nil {
			logger.Err(err).Str("user", userID).Str("device", deviceID).Int("num_txns", len(eventIDToTxnID)).Msg("failed to persist txn IDs for user")
			internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
		}
	}

	result, err := h.Store.Accumulator.ProcessRoomEvents(userID, roomID, timeline, state, prevBatch)
	if err != nil {
		logger.Err(err).Str("room", roomID).Msg("V2: failed to process new room events")
		return
	}
	if result.BrandNewSnapshot && result.SnapshotID > 0 {
		h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Initialise{
			RoomID:      roomID,
			SnapshotNID: result.SnapshotID,
		})
	}
	if len(result.TimelineNIDs) > 0 {
		h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
			RoomID:    roomID,
			PrevBatch: prevBatch,
			EventNIDs: result.TimelineNIDs,
		})
	}

	if len(eventIDToTxnID) == 0 && len(eventIDsLackingTxns) == 0 {
		return // no transaction ID bookkeeping to do
	}

	// The call to ProcessRoomEvents above only tells us about new events' NIDs;
	// for existing events we need to requery the database to fetch them.
	// Rather than try to reuse work, keep things simple and just fetch NIDs for
	// all events of interest.
	// Build a fresh slice: appending directly onto eventIDsWithTxns would write into
	// the spare capacity of its backing array.
	eventIDsToFetch := make([]string, 0, len(eventIDsWithTxns)+len(eventIDsLackingTxns))
	eventIDsToFetch = append(eventIDsToFetch, eventIDsWithTxns...)
	eventIDsToFetch = append(eventIDsToFetch, eventIDsLackingTxns...)

	var nidsByIDs map[string]int64
	err = sqlutil.WithTransaction(h.Store.DB, func(txn *sqlx.Tx) error {
		nidsByIDs, err = h.Store.EventsTable.SelectNIDsByIDs(txn, eventIDsToFetch)
		return err
	})
	if err != nil {
		logger.Err(err).
			Int("timeline", len(timeline)).
			Int("num_transaction_ids", len(eventIDsWithTxns)).
			Int("num_missing_transaction_ids", len(eventIDsLackingTxns)).
			Str("room", roomID).
			Msg("V2: failed to fetch nids for event transaction_id handling")
		internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
		return
	}

	for eventID, nid := range nidsByIDs {
		// txnID is the empty string when this event has no transaction ID.
		txnID, hasTxnID := eventIDToTxnID[eventID]
		if hasTxnID {
			h.PendingTxnIDs.SeenTxnID(eventID)
		} else {
			// NOTE(review): the second return value of MissingTxnID is discarded, as in
			// the original code; confirm this is intentional.
			allClear, _ := h.PendingTxnIDs.MissingTxnID(eventID, userID, deviceID)
			if !allClear {
				continue
			}
		}
		h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{
			EventID:       eventID,
			RoomID:        roomID,
			UserID:        userID,
			DeviceID:      deviceID,
			TransactionID: txnID,
			NID:           nid,
		})
	}
}

func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.RawMessage) []json.RawMessage {
res, err := h.Store.Initialise(roomID, state)
if err != nil {
Expand Down
Loading

0 comments on commit d6ebcde

Please sign in to comment.