diff --git a/state/accumulator.go b/state/accumulator.go index d8005500..8966f1ca 100644 --- a/state/accumulator.go +++ b/state/accumulator.go @@ -399,8 +399,8 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch var latestNID int64 newEvents := make([]Event, 0, len(eventIDToNID)) - var redactTheseEventIDs []string - for _, ev := range dedupedEvents { + redactTheseEventIDs := make(map[string]*Event) + for i, ev := range dedupedEvents { nid, ok := eventIDToNID[ev.ID] if ok { ev.NID = int64(nid) @@ -423,7 +423,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch redactsEventID = parsedEv.Get("content.redacts").Str } if redactsEventID != "" { - redactTheseEventIDs = append(redactTheseEventIDs, redactsEventID) + redactTheseEventIDs[redactsEventID] = &dedupedEvents[i] } } newEvents = append(newEvents, ev) diff --git a/state/event_table.go b/state/event_table.go index a793c413..035a4333 100644 --- a/state/event_table.go +++ b/state/event_table.go @@ -338,7 +338,11 @@ func (t *EventTable) LatestEventNIDInRooms(txn *sqlx.Tx, roomIDs []string, highe return } -func (t *EventTable) Redact(txn *sqlx.Tx, roomVer string, eventIDs []string) error { +func (t *EventTable) Redact(txn *sqlx.Tx, roomVer string, redacteeEventIDToRedactEvent map[string]*Event) error { + eventIDs := make([]string, 0, len(redacteeEventIDToRedactEvent)) + for e := range redacteeEventIDToRedactEvent { + eventIDs = append(eventIDs, e) + } // verifyAll=false so if we are asked to redact an event we don't have we don't fall over. eventsToRedact, err := t.SelectByIDs(txn, false, eventIDs) if err != nil { @@ -357,6 +361,13 @@ func (t *EventTable) Redact(txn *sqlx.Tx, roomVer string, eventIDs []string) err if err != nil { return fmt.Errorf("RedactEventJSON[%s]: %w", eventsToRedact[i].ID, err) } + // also set unsigned.redacted_because as EX relies on it + eventsToRedact[i].JSON, err = sjson.SetBytes( + eventsToRedact[i].JSON, "unsigned.redacted_because", json.RawMessage(redacteeEventIDToRedactEvent[eventsToRedact[i].ID].JSON), + ) + if err != nil { + return fmt.Errorf("RedactEventJSON[%s]: setting redacted_because %w", eventsToRedact[i].ID, err) + } _, err = txn.Exec(`UPDATE syncv3_events SET event=$1 WHERE event_id=$2`, eventsToRedact[i].JSON, eventsToRedact[i].ID) if err != nil { return fmt.Errorf("cannot update event %s: %w", eventsToRedact[i].ID, err) diff --git a/state/event_table_test.go b/state/event_table_test.go index 7d6fb2c8..f6f0a7bb 100644 --- a/state/event_table_test.go +++ b/state/event_table_test.go @@ -1063,7 +1063,17 @@ func TestEventTableRedact(t *testing.T) { JSON: j, }, }, false) - assertNoError(t, table.Redact(txn, tc.roomVer, []string{eventID})) + assertNoError(t, table.Redact(txn, tc.roomVer, map[string]*Event{ + eventID: &Event{ + JSON: json.RawMessage(`{ + "type": "m.room.redaction", + "event_id": "$unimportant", + "content": { + "redacts": "` + eventID + `" + } + }`), + }, + })) gots, err := table.SelectByIDs(txn, true, []string{eventID}) assertNoError(t, err) if len(gots) != 1 { @@ -1088,5 +1098,32 @@ func TestEventTableRedactMissingOK(t *testing.T) { t.Fatalf("failed to start txn: %s", err) } defer txn.Rollback() - assertNoError(t, table.Redact(txn, "2", []string{"$unknown", "$event", "$ids"})) + assertNoError(t, table.Redact(txn, "2", map[string]*Event{ + "$unknown": { + JSON: json.RawMessage(`{ + "type": "m.room.redaction", + "event_id": "$unimportant", + "content": { + "redacts": "$unknown" + } + }`), + }, + "$event": { + JSON: json.RawMessage(`{ + "type": "m.room.redaction", + "event_id": "$unimportant", + "content": { + "redacts": "$event" + } + }`), + }, + "$ids": { + JSON: json.RawMessage(`{ + "type": "m.room.redaction", + "event_id": "$unimportant", + "content": { + "redacts": "$ids" + } + }`), + }})) } diff --git a/tests-e2e/main_test.go b/tests-e2e/main_test.go index 2ea63800..e6b42491 100644 --- a/tests-e2e/main_test.go +++ b/tests-e2e/main_test.go @@ -77,6 +77,15 @@ func eventsEqual(wantList []Event, gotList []json.RawMessage) error { if want.Sender != "" && want.Sender != got.Sender { return fmt.Errorf("event %d Sender mismatch: got %v want %v", i, got.Sender, want.Sender) } + // loop each key on unsigned as unsigned also includes "age" which is unpredictable so cannot DeepEqual + if want.Unsigned != nil { + for k, v := range want.Unsigned { + got := got.Unsigned[k] + if !reflect.DeepEqual(got, v) { + return fmt.Errorf("event %d Unsigned.%s mismatch: got %v want %v", i, k, got, v) + } + } + } } return nil } @@ -218,3 +227,10 @@ func registerNamedUser(t *testing.T, localpartPrefix string) *CSAPI { func ptr(s string) *string { return &s } + +func assertEqual(t *testing.T, msg string, got, want interface{}) { + t.Helper() + if !reflect.DeepEqual(got, want) { + t.Fatalf("%s: got %v want %v", msg, got, want) + } +} diff --git a/tests-e2e/redaction_test.go b/tests-e2e/redaction_test.go index 090dbf33..54af73ca 100644 --- a/tests-e2e/redaction_test.go +++ b/tests-e2e/redaction_test.go @@ -5,6 +5,7 @@ import ( "github.com/matrix-org/sliding-sync/sync3" "github.com/matrix-org/sliding-sync/testutils/m" + "github.com/tidwall/gjson" ) func TestRedactionsAreRedactedWherePossible(t *testing.T) { @@ -59,5 +60,18 @@ func TestRedactionsAreRedactedWherePossible(t *testing.T) { {ID: redactionEventID}, })}, })) + // introspect the unsigned key a bit more, we don't know all the fields so can't use a matcher + gotEvent := gjson.ParseBytes(res.Rooms[room].Timeline[len(res.Rooms[room].Timeline)-2]) + redactedBecause := gotEvent.Get("unsigned.redacted_because") + if !redactedBecause.Exists() { + t.Fatalf("unsigned.redacted_because must exist, but it doesn't. Got: %v", gotEvent.Raw) + } + // assert basic fields + assertEqual(t, "event_id mismatch", redactedBecause.Get("event_id").Str, redactionEventID) + assertEqual(t, "sender mismatch", redactedBecause.Get("sender").Str, alice.UserID) + assertEqual(t, "type mismatch", redactedBecause.Get("type").Str, "m.room.redaction") + if !redactedBecause.Get("content").Exists() { + t.Fatalf("unsigned.redacted_because.content must exist, but it doesn't. Got: %v", gotEvent.Raw) + } }