Skip to content

Commit

Permalink
Make ConnID hold a UserID
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson committed Apr 28, 2023
1 parent adb3ba3 commit ca8a2d7
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 26 deletions.
2 changes: 2 additions & 0 deletions pubsub/v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type V2InitialSyncComplete struct {
func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }

type V2DeviceData struct {
UserID string
DeviceID string
Pos int64
}
Expand Down Expand Up @@ -106,6 +107,7 @@ type V2DeviceMessages struct {
func (*V2DeviceMessages) Type() string { return "V2DeviceMessages" }

type V2ExpiredToken struct {
UserID string
DeviceID string
}

Expand Down
2 changes: 2 additions & 0 deletions sync2/handler2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ func (h *Handler) OnExpiredToken(userID, deviceID string) {
h.Store.DeviceDataTable.DeleteDevice(userID, deviceID)
// also notify v3 side so it can remove the connection from ConnMap
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2ExpiredToken{
UserID: userID,
DeviceID: deviceID,
})
}
Expand Down Expand Up @@ -201,6 +202,7 @@ func (h *Handler) OnE2EEData(userID, deviceID string, otkCounts map[string]int,
return
}
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
UserID: userID,
DeviceID: deviceID,
Pos: nextPos,
})
Expand Down
14 changes: 5 additions & 9 deletions sync3/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (
)

type ConnID struct {
UserID string
DeviceID string
}

func (c *ConnID) String() string {
return c.DeviceID
return c.UserID + "|" + c.DeviceID
}

type ConnHandler interface {
Expand All @@ -24,7 +25,6 @@ type ConnHandler interface {
// status code to send back.
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error)
OnUpdate(ctx context.Context, update caches.Update)
UserID() string
Destroy()
Alive() bool
}
Expand All @@ -33,7 +33,7 @@ type ConnHandler interface {
// of the /sync request, including sending cached data in the event of retries. It does not handle
// the contents of the data at all.
type Conn struct {
ConnID ConnID
ConnID

handler ConnHandler

Expand Down Expand Up @@ -65,10 +65,6 @@ func NewConn(connID ConnID, h ConnHandler) *Conn {
}
}

func (c *Conn) UserID() string {
return c.handler.UserID()
}

func (c *Conn) Alive() bool {
return c.handler.Alive()
}
Expand Down Expand Up @@ -105,7 +101,7 @@ func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err
}
ctx, task := internal.StartTask(ctx, taskType)
defer task.End()
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.handler.UserID(), c.ConnID.DeviceID, req.pos)
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.UserID, c.ConnID.DeviceID, req.pos)
return c.handler.OnIncomingRequest(ctx, c.ConnID, req, req.pos == 0)
}

Expand Down Expand Up @@ -164,7 +160,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Respo
c.serverResponses = c.serverResponses[delIndex+1:] // slice out the first delIndex+1 elements

defer func() {
l := logger.Trace().Int("num_res_acks", delIndex+1).Bool("is_retransmit", isRetransmit).Bool("is_first", isFirstRequest).Bool("is_same", isSameRequest).Int64("pos", req.pos).Str("user", c.handler.UserID())
l := logger.Trace().Int("num_res_acks", delIndex+1).Bool("is_retransmit", isRetransmit).Bool("is_first", isFirstRequest).Bool("is_same", isSameRequest).Int64("pos", req.pos).Str("user", c.UserID)
if nextUnACKedResponse != nil {
l.Int64("new_pos", nextUnACKedResponse.PosInt())
}
Expand Down
14 changes: 7 additions & 7 deletions sync3/connmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
conn = NewConn(cid, h)
m.cache.Set(cid.String(), conn)
m.connIDToConn[cid.String()] = conn
m.userIDToConn[h.UserID()] = append(m.userIDToConn[h.UserID()], conn)
m.userIDToConn[cid.UserID] = append(m.userIDToConn[cid.UserID], conn)
return conn, true
}

Expand All @@ -94,20 +94,20 @@ func (m *ConnMap) closeConn(conn *Conn) {
return
}

connID := conn.ConnID.String()
logger.Trace().Str("conn", connID).Msg("closing connection")
connKey := conn.ConnID.String()
logger.Trace().Str("conn", connKey).Msg("closing connection")
// remove conn from all the maps
delete(m.connIDToConn, connID)
delete(m.connIDToConn, connKey)
h := conn.handler
conns := m.userIDToConn[h.UserID()]
conns := m.userIDToConn[conn.UserID]
for i := 0; i < len(conns); i++ {
if conns[i].ConnID.String() == connID {
if conns[i].DeviceID == conn.DeviceID {
// delete without preserving order
conns[i] = conns[len(conns)-1]
conns = conns[:len(conns)-1]
}
}
m.userIDToConn[h.UserID()] = conns
m.userIDToConn[conn.UserID] = conns
// remove user cache listeners etc
h.Destroy()
}
19 changes: 9 additions & 10 deletions sync3/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
return herr
}
requestBody.SetPos(cpos)
internal.SetRequestContextUserID(req.Context(), conn.UserID())
log := hlog.FromRequest(req).With().Str("user", conn.UserID()).Int64("pos", cpos).Logger()
internal.SetRequestContextUserID(req.Context(), conn.UserID)
log := hlog.FromRequest(req).With().Str("user", conn.UserID).Int64("pos", cpos).Logger()

var timeout int
if req.URL.Query().Get("timeout") == "" {
Expand Down Expand Up @@ -320,7 +320,10 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
log.Warn().Msg("Unable to update last seen timestamp")
}

connID := connIDFromToken(token)
connID := sync3.ConnID{
UserID: token.UserID,
DeviceID: token.DeviceID,
}
// client thinks they have a connection
if containsPos {
// Lookup the connection
Expand Down Expand Up @@ -375,13 +378,6 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
return conn, nil
}

func connIDFromToken(token *sync2.Token) sync3.ConnID {
return sync3.ConnID{
// TODO: change ConnID to be a (user, device) ID pair
DeviceID: token.DeviceID,
}
}

func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string) (*sync2.Token, *internal.HandlerError) {
// We don't recognise the given accessToken. Ask the homeserver who owns it.
userID, deviceID, err := h.V2.WhoAmI(accessToken)
Expand Down Expand Up @@ -595,6 +591,7 @@ func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
defer task.End()
conn := h.ConnMap.Conn(sync3.ConnID{
UserID: p.UserID,
DeviceID: p.DeviceID,
})
if conn == nil {
Expand All @@ -607,6 +604,7 @@ func (h *SyncLiveHandler) OnDeviceMessages(p *pubsub.V2DeviceMessages) {
ctx, task := internal.StartTask(context.Background(), "OnDeviceMessages")
defer task.End()
conn := h.ConnMap.Conn(sync3.ConnID{
UserID: p.UserID,
DeviceID: p.DeviceID,
})
if conn == nil {
Expand Down Expand Up @@ -703,6 +701,7 @@ func (h *SyncLiveHandler) OnAccountData(p *pubsub.V2AccountData) {

func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) {
h.ConnMap.CloseConn(sync3.ConnID{
UserID: p.UserID,
DeviceID: p.DeviceID,
})
}
Expand Down

0 comments on commit ca8a2d7

Please sign in to comment.