Skip to content

Commit

Permalink
Have WhoAmI extract the device_id
Browse files Browse the repository at this point in the history
Useful for #51, small enough to include in isolation
  • Loading branch information
David Robertson committed Apr 11, 2023
1 parent 53f6d5e commit 846197e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 11 deletions.
20 changes: 12 additions & 8 deletions sync2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ var ProxyVersion = ""
var HTTP401 error = fmt.Errorf("HTTP 401")

type Client interface {
WhoAmI(accessToken string) (string, error)
// WhoAmI asks the homeserver to lookup the access token using the CSAPI /whoami
// endpoint. The response must contain a device ID (meaning that we assume the
// homeserver supports Matrix >= 1.1.)
WhoAmI(accessToken string) (userID, deviceID string, err error)
DoSyncV2(ctx context.Context, accessToken, since string, isFirst bool, toDeviceOnly bool) (*SyncResponse, int, error)
}

Expand All @@ -30,29 +33,30 @@ type HTTPClient struct {
}

// Return sync2.HTTP401 if this request returns 401
func (v *HTTPClient) WhoAmI(accessToken string) (string, error) {
func (v *HTTPClient) WhoAmI(accessToken string) (string, string, error) {
req, err := http.NewRequest("GET", v.DestinationServer+"/_matrix/client/r0/account/whoami", nil)
if err != nil {
return "", err
return "", "", err
}
req.Header.Set("User-Agent", "sync-v3-proxy-"+ProxyVersion)
req.Header.Set("Authorization", "Bearer "+accessToken)
res, err := v.Client.Do(req)
if err != nil {
return "", err
return "", "", err
}
if res.StatusCode != 200 {
if res.StatusCode == 401 {
return "", HTTP401
return "", "", HTTP401
}
return "", fmt.Errorf("/whoami returned HTTP %d", res.StatusCode)
return "", "", fmt.Errorf("/whoami returned HTTP %d", res.StatusCode)
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return "", err
return "", "", err
}
return gjson.GetBytes(body, "user_id").Str, nil
response := gjson.ParseBytes(body)
return response.Get("user_id").Str, response.Get("device_id").Str, nil
}

// DoSyncV2 performs a sync v2 request. Returns the sync response and the response status code
Expand Down
4 changes: 2 additions & 2 deletions sync2/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,8 @@ type mockClient struct {
func (c *mockClient) DoSyncV2(ctx context.Context, authHeader, since string, isFirst, toDeviceOnly bool) (*SyncResponse, int, error) {
return c.fn(authHeader, since)
}
func (c *mockClient) WhoAmI(authHeader string) (string, error) {
return "@alice:localhost", nil
func (c *mockClient) WhoAmI(authHeader string) (string, string, error) {
return "@alice:localhost", "device_123", nil
}

type mockDataReceiver struct {
Expand Down
2 changes: 1 addition & 1 deletion sync3/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
}
}
if v2device.UserID == "" {
v2device.UserID, err = h.V2.WhoAmI(accessToken)
v2device.UserID, _, err = h.V2.WhoAmI(accessToken)
if err != nil {
if err == sync2.HTTP401 {
return nil, &internal.HandlerError{
Expand Down

0 comments on commit 846197e

Please sign in to comment.