Skip to content

Commit

Permalink
feature: add txnids to events
Browse files Browse the repository at this point in the history
Clients rely on transaction IDs coming down their /sync streams so they
can pair up an incoming event with an event they just sent but have not
yet got the event ID for.

The proxy has not historically handled this because of the shared work
model of operation, where we store exactly 1 copy of the event in the
database and no more. This means that if Alice and Bob are using the
same proxy and Alice sends a message, Bob's /sync stream may receive the
event first, and that copy will NOT contain the `transaction_id`. That copy
then gets written into the database. Later, when Alice /syncs, she will not
get the `transaction_id` for the event which she sent.

This commit fixes this by having a TTL cache which maps (user, event)
-> txn_id. Transaction IDs are inherently ephemeral, so keeping the
last 5 minutes worth of txn IDs in-memory is an easy solution which
will be good enough for the proxy. Actual server implementations of
sliding sync will be able to trivially deal with this behaviour natively.
  • Loading branch information
kegsay committed Mar 28, 2022
1 parent 53480c1 commit 2920191
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 14 deletions.
31 changes: 29 additions & 2 deletions sync2/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ type E2EEFetcher interface {
LatestE2EEData(deviceID string) (otkCounts map[string]int, changed, left []string)
}

// TransactionIDFetcher looks up the transaction ID associated with an event for a given user.
type TransactionIDFetcher interface {
// TransactionIDForEvent returns the transaction ID for the (userID, eventID) pair.
// Callers treat the empty string as "no transaction ID known".
TransactionIDForEvent(userID, eventID string) (txnID string)
}

// PollerMap is a map of device ID to Poller
type PollerMap struct {
v2Client Client
Expand All @@ -42,6 +46,7 @@ type PollerMap struct {
Pollers map[string]*Poller // device_id -> poller
executor chan func()
executorRunning bool
txnCache *TransactionIDCache
}

// NewPollerMap makes a new PollerMap. Guarantees that the V2DataReceiver will be called on the same
Expand Down Expand Up @@ -72,9 +77,15 @@ func NewPollerMap(v2Client Client, callbacks V2DataReceiver) *PollerMap {
pollerMu: &sync.Mutex{},
Pollers: make(map[string]*Poller),
executor: make(chan func(), 0),
txnCache: NewTransactionIDCache(),
}
}

// TransactionIDForEvent looks up the transaction ID recorded for eventID on behalf of userID.
// The lookup is delegated to the poller map's transaction ID cache; the result is whatever
// that cache returns for the pair.
func (h *PollerMap) TransactionIDForEvent(userID, eventID string) string {
	txnID := h.txnCache.Get(userID, eventID)
	return txnID
}

// LatestE2EEData pulls the latest device_lists and device_one_time_keys_count values from the poller.
// These bits of data are ephemeral and do not need to be persisted.
func (h *PollerMap) LatestE2EEData(deviceID string) (otkCounts map[string]int, changed, left []string) {
Expand Down Expand Up @@ -114,7 +125,7 @@ func (h *PollerMap) EnsurePolling(authHeader, userID, deviceID, v2since string,
return
}
// replace the poller
poller = NewPoller(userID, authHeader, deviceID, h.v2Client, h, logger)
poller = NewPoller(userID, authHeader, deviceID, h.v2Client, h, h.txnCache, logger)
go poller.Poll(v2since)
h.Pollers[deviceID] = poller

Expand Down Expand Up @@ -210,6 +221,9 @@ type Poller struct {
receiver V2DataReceiver
logger zerolog.Logger

// remember txn ids
txnCache *TransactionIDCache

// E2EE fields
e2eeMu *sync.Mutex
otkCounts map[string]int
Expand All @@ -220,7 +234,7 @@ type Poller struct {
wg *sync.WaitGroup
}

func NewPoller(userID, authHeader, deviceID string, client Client, receiver V2DataReceiver, logger zerolog.Logger) *Poller {
func NewPoller(userID, authHeader, deviceID string, client Client, receiver V2DataReceiver, txnCache *TransactionIDCache, logger zerolog.Logger) *Poller {
var wg sync.WaitGroup
wg.Add(1)
return &Poller{
Expand All @@ -234,6 +248,7 @@ func NewPoller(userID, authHeader, deviceID string, client Client, receiver V2Da
e2eeMu: &sync.Mutex{},
deviceListChanges: make(map[string]string),
wg: &wg,
txnCache: txnCache,
}
}

Expand Down Expand Up @@ -342,6 +357,17 @@ func (p *Poller) parseGlobalAccountData(res *SyncResponse) {
p.receiver.OnAccountData(p.userID, AccountDataGlobalRoom, res.AccountData.Events)
}

// updateTxnIDCache scans the given timeline events and, for each event which carries an
// `unsigned.transaction_id`, remembers that transaction ID against (this poller's user, event ID)
// in the transaction ID cache.
//
// Events are skipped when the transaction ID is absent, empty or not a JSON string, and when the
// event has no usable `event_id`: storing such entries would either be useless (an empty
// transaction ID is indistinguishable from "not found") or would create a bogus cache key with an
// empty event ID.
func (p *Poller) updateTxnIDCache(timeline []json.RawMessage) {
	for _, e := range timeline {
		// Str is only populated for JSON string values, so this single check covers
		// "missing", "empty" and "wrong type".
		txnID := gjson.GetBytes(e, "unsigned.transaction_id").Str
		if txnID == "" {
			continue
		}
		eventID := gjson.GetBytes(e, "event_id").Str
		if eventID == "" {
			// without an event ID there is nothing meaningful to key the transaction ID on
			continue
		}
		p.txnCache.Store(p.userID, eventID, txnID)
	}
}

func (p *Poller) parseRoomsResponse(res *SyncResponse) {
stateCalls := 0
timelineCalls := 0
Expand All @@ -363,6 +389,7 @@ func (p *Poller) parseRoomsResponse(res *SyncResponse) {
}
if len(roomData.Timeline.Events) > 0 {
timelineCalls++
p.updateTxnIDCache(roomData.Timeline.Events)
p.receiver.Accumulate(roomID, roomData.Timeline.Events)
}
for _, ephEvent := range roomData.Ephemeral.Events {
Expand Down
8 changes: 5 additions & 3 deletions sync2/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/rs/zerolog"
)

// txnIDCache is a package-level transaction ID cache which the poller tests below pass to NewPoller.
var txnIDCache = NewTransactionIDCache()

// Check that a call to Poll starts polling and accumulating, and terminates on 401s.
func TestPollerPollFromNothing(t *testing.T) {
nextSince := "next"
Expand Down Expand Up @@ -44,7 +46,7 @@ func TestPollerPollFromNothing(t *testing.T) {
})
var wg sync.WaitGroup
wg.Add(1)
poller := NewPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, zerolog.New(os.Stderr))
poller := NewPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, txnIDCache, zerolog.New(os.Stderr))
go func() {
defer wg.Done()
poller.Poll("")
Expand Down Expand Up @@ -127,7 +129,7 @@ func TestPollerPollFromExisting(t *testing.T) {
})
var wg sync.WaitGroup
wg.Add(1)
poller := NewPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, zerolog.New(os.Stderr))
poller := NewPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, txnIDCache, zerolog.New(os.Stderr))
go func() {
defer wg.Done()
poller.Poll(since)
Expand Down Expand Up @@ -203,7 +205,7 @@ func TestPollerBackoff(t *testing.T) {
}
var wg sync.WaitGroup
wg.Add(1)
poller := NewPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, zerolog.New(os.Stderr))
poller := NewPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, txnIDCache, zerolog.New(os.Stderr))
go func() {
defer wg.Done()
poller.Poll("some_since_value")
Expand Down
38 changes: 38 additions & 0 deletions sync2/txnid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package sync2

import (
"time"

"github.com/ReneKroon/ttlcache/v2"
)

// TransactionIDCache remembers transaction IDs keyed by (user ID, event ID), backed by a TTL cache.
type TransactionIDCache struct {
// cache maps cacheKey(userID, eventID) to the transaction ID string.
cache *ttlcache.Cache
}

// NewTransactionIDCache creates an empty TransactionIDCache whose entries are forgotten
// after a fixed 5 minute TTL.
func NewTransactionIDCache() *TransactionIDCache {
	// keep transaction IDs for 5 minutes before forgetting about them
	const ttl = 5 * time.Minute

	txnCache := &TransactionIDCache{
		cache: ttlcache.NewCache(),
	}
	txnCache.cache.SetTTL(ttl)
	// Reads must not prolong an entry's life: however often an item is asked for, 5min is the limit.
	txnCache.cache.SkipTTLExtensionOnHit(true)
	return txnCache
}

// Store records txnID, a transaction ID received via v2 /sync, against the (userID, eventID) pair.
func (c *TransactionIDCache) Store(userID, eventID, txnID string) {
	key := cacheKey(userID, eventID)
	c.cache.Set(key, txnID)
}

// Get returns the transaction ID previously stored for the (userID, eventID) pair, or the
// empty string when the cache lookup yields no value.
func (c *TransactionIDCache) Get(userID, eventID string) string {
	// The lookup error is intentionally ignored: a nil value is all we need to know
	// that there is no transaction ID to return.
	stored, _ := c.cache.Get(cacheKey(userID, eventID))
	if stored == nil {
		return ""
	}
	return stored.(string)
}

// cacheKey builds the cache key for a (userID, eventID) pair by joining the two
// with a single space.
func cacheKey(userID, eventID string) string {
	const separator = " "
	return userID + separator + eventID
}
55 changes: 55 additions & 0 deletions sync2/txnid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package sync2

import "testing"

// TestTransactionIDCache stores a handful of transaction IDs and checks that lookups are
// scoped to the (user, event) pair: known pairs return their transaction ID, while an unknown
// event or an unknown user returns the empty string.
func TestTransactionIDCache(t *testing.T) {
	const (
		alice  = "@alice:localhost"
		bob    = "@bob:localhost"
		eventA = "$a:localhost"
		eventB = "$b:localhost"
		eventC = "$c:localhost"
		txn1   = "1"
		txn2   = "2"
	)

	cache := NewTransactionIDCache()
	cache.Store(alice, eventA, txn1)
	cache.Store(bob, eventB, txn1) // different users can use same txn ID
	cache.Store(alice, eventC, txn2)

	// Rows are positional: eventID, userID, want.
	testCases := []struct {
		eventID string
		userID  string
		want    string
	}{
		{eventA, alice, txn1},
		{eventB, bob, txn1},
		{eventC, alice, txn2},
		{"$invalid", alice, ""},
		{eventA, "@invalid", ""},
	}
	for _, tc := range testCases {
		if got := cache.Get(tc.userID, tc.eventID); got != tc.want {
			t.Errorf("%+v: got %v want %v", tc, got, tc.want)
		}
	}
}
27 changes: 26 additions & 1 deletion sync3/caches/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (

"github.com/matrix-org/sync-v3/internal"
"github.com/matrix-org/sync-v3/state"
"github.com/matrix-org/sync-v3/sync2"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)

type UserRoomData struct {
Expand Down Expand Up @@ -38,9 +40,10 @@ type UserCache struct {
id int
store *state.Storage
globalCache *GlobalCache
txnIDs sync2.TransactionIDFetcher
}

func NewUserCache(userID string, globalCache *GlobalCache, store *state.Storage) *UserCache {
func NewUserCache(userID string, globalCache *GlobalCache, store *state.Storage, txnIDs sync2.TransactionIDFetcher) *UserCache {
uc := &UserCache{
UserID: userID,
roomToDataMu: &sync.RWMutex{},
Expand All @@ -49,6 +52,7 @@ func NewUserCache(userID string, globalCache *GlobalCache, store *state.Storage)
listenersMu: &sync.Mutex{},
store: store,
globalCache: globalCache,
txnIDs: txnIDs,
}
return uc
}
Expand Down Expand Up @@ -181,6 +185,27 @@ func (c *UserCache) newRoomUpdate(roomID string) RoomUpdate {
}
}

// AnnotateWithTransactionIDs should be called just before events are returned to the client.
// For every event for which this user has a known transaction ID, the event JSON is replaced
// with a copy that has `unsigned.transaction_id` set. This is needed because events are globally
// scoped: if Alice sends a message, Bob might receive it first on his v2 loop, in which case the
// stored event has no transaction ID. Transaction IDs are therefore looked up separately and
// applied at request time.
//
// The elements of the events slice are overwritten in place and the same slice is returned.
// Events whose JSON cannot be updated are left untouched and the failure is logged.
func (c *UserCache) AnnotateWithTransactionIDs(events []json.RawMessage) []json.RawMessage {
	for idx, raw := range events {
		id := gjson.GetBytes(raw, "event_id")
		txnID := c.txnIDs.TransactionIDForEvent(c.UserID, id.Str)
		if txnID == "" {
			// nothing known for this user/event pair
			continue
		}
		annotated, err := sjson.SetBytes(raw, "unsigned.transaction_id", txnID)
		if err != nil {
			logger.Err(err).Str("user", c.UserID).Msg("AnnotateWithTransactionIDs: sjson failed")
			continue
		}
		events[idx] = annotated
	}
	return events
}

// =================================================
// Listener functions called by v2 pollers are below
// =================================================
Expand Down
6 changes: 3 additions & 3 deletions sync3/handler/connstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,9 @@ func (s *ConnState) getDeltaRoomData(roomID string, event json.RawMessage) *sync
HighlightCount: int64(userRoomData.HighlightCount),
}
if event != nil {
room.Timeline = []json.RawMessage{
room.Timeline = s.userCache.AnnotateWithTransactionIDs([]json.RawMessage{
event,
}
})
}
return room
}
Expand All @@ -342,7 +342,7 @@ func (s *ConnState) getInitialRoomData(listIndex int, timelineLimit int, roomIDs
Name: internal.CalculateRoomName(metadata, 5), // TODO: customisable?
NotificationCount: int64(userRoomData.NotificationCount),
HighlightCount: int64(userRoomData.HighlightCount),
Timeline: userRoomData.Timeline,
Timeline: s.userCache.AnnotateWithTransactionIDs(userRoomData.Timeline),
RequiredState: s.globalCache.LoadRoomState(roomID, s.loadPosition, s.muxedReq.GetRequiredState(listIndex, roomID)),
Initial: true,
}
Expand Down
14 changes: 10 additions & 4 deletions sync3/handler/connstate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ func (t *NopJoinTracker) IsUserJoined(userID, roomID string) bool {
return true
}

// NopTransactionFetcher is a do-nothing transaction ID fetcher: it never finds a transaction ID.
type NopTransactionFetcher struct{}

// TransactionIDForEvent ignores its arguments and always reports the empty string.
func (*NopTransactionFetcher) TransactionIDForEvent(userID, eventID string) (txnID string) {
	// txnID is never assigned, so this is always "".
	return txnID
}

func newRoomMetadata(roomID string, lastMsgTimestamp gomatrixserverlib.Timestamp) internal.RoomMetadata {
return internal.RoomMetadata{
RoomID: roomID,
Expand Down Expand Up @@ -87,7 +93,7 @@ func TestConnStateInitial(t *testing.T) {
&roomA, &roomB, &roomC,
}, nil
}
userCache := caches.NewUserCache(userID, globalCache, nil)
userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
dispatcher.Register(userCache.UserID, userCache)
dispatcher.Register(sync3.DispatcherAllUsers, globalCache)
userCache.LazyRoomDataOverride = func(loadPos int64, roomIDs []string, maxTimelineEvents int) map[string]caches.UserRoomData {
Expand Down Expand Up @@ -244,7 +250,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms []*internal.RoomMetadata, err error) {
return 1, rooms, nil
}
userCache := caches.NewUserCache(userID, globalCache, nil)
userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
userCache.LazyRoomDataOverride = mockLazyRoomOverride
dispatcher.Register(userCache.UserID, userCache)
dispatcher.Register(sync3.DispatcherAllUsers, globalCache)
Expand Down Expand Up @@ -426,7 +432,7 @@ func TestBumpToOutsideRange(t *testing.T) {
&roomA, &roomB, &roomC, &roomD,
}, nil
}
userCache := caches.NewUserCache(userID, globalCache, nil)
userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
userCache.LazyRoomDataOverride = mockLazyRoomOverride
dispatcher.Register(userCache.UserID, userCache)
dispatcher.Register(sync3.DispatcherAllUsers, globalCache)
Expand Down Expand Up @@ -524,7 +530,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
&roomA, &roomB, &roomC, &roomD,
}, nil
}
userCache := caches.NewUserCache(userID, globalCache, nil)
userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
userCache.LazyRoomDataOverride = func(loadPos int64, roomIDs []string, maxTimelineEvents int) map[string]caches.UserRoomData {
result := make(map[string]caches.UserRoomData)
for _, roomID := range roomIDs {
Expand Down
2 changes: 1 addition & 1 deletion sync3/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func (h *SyncLiveHandler) userCache(userID string) (*caches.UserCache, error) {
if ok {
return c.(*caches.UserCache), nil
}
uc := caches.NewUserCache(userID, h.GlobalCache, h.Storage)
uc := caches.NewUserCache(userID, h.GlobalCache, h.Storage, h.PollerMap)
// select all non-zero highlight or notif counts and set them, as this is less costly than looping every room/user pair
err := h.Storage.UnreadTable.SelectAllNonZeroCountsForUser(userID, func(roomID string, highlightCount, notificationCount int) {
uc.OnUnreadCounts(roomID, &highlightCount, &notificationCount)
Expand Down

0 comments on commit 2920191

Please sign in to comment.