Skip to content

Commit

Permalink
Add support for MSC3202 in appservice package
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jul 14, 2021
1 parent 18c5531 commit 72aa965
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 59 deletions.
16 changes: 8 additions & 8 deletions appservice/appservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (

// EventChannelSize is the size for the Events channel in Appservice instances.
var EventChannelSize = 64
var OTKChannelSize = 4

// Create a blank appservice instance.
func Create() *AppService {
Expand Down Expand Up @@ -83,20 +84,17 @@ type AppService struct {
RegistrationPath string `yaml:"registration"`
Host HostConfig `yaml:"host"`
LogConfig LogConfig `yaml:"logging"`
Sync struct {
Enabled bool `yaml:"enabled"`
FilterID string `yaml:"filter_id"`
NextBatch string `yaml:"next_batch"`
} `yaml:"sync"`

Registration *Registration `yaml:"-"`
Log maulogger.Logger `yaml:"-"`

lastProcessedTransaction string

Events chan *event.Event `yaml:"-"`
QueryHandler QueryHandler `yaml:"-"`
StateStore StateStore `yaml:"-"`
Events chan *event.Event `yaml:"-"`
DeviceLists chan *mautrix.DeviceLists `yaml:"-"`
OTKCounts chan *mautrix.OTKCount `yaml:"-"`
QueryHandler QueryHandler `yaml:"-"`
StateStore StateStore `yaml:"-"`

Router *mux.Router `yaml:"-"`
UserAgent string `yaml:"-"`
Expand Down Expand Up @@ -246,6 +244,8 @@ func (as *AppService) BotClient() *mautrix.Client {
// Init initializes the logger and loads the registration of this appservice.
func (as *AppService) Init() (bool, error) {
as.Events = make(chan *event.Event, EventChannelSize)
as.OTKCounts = make(chan *mautrix.OTKCount, OTKChannelSize)
as.DeviceLists = make(chan *mautrix.DeviceLists, EventChannelSize)
as.QueryHandler = &QueryHandlerStub{}

if len(as.UserAgent) == 0 {
Expand Down
57 changes: 51 additions & 6 deletions appservice/eventprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

log "maunium.net/go/maulogger/v2"

"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
)

Expand All @@ -24,6 +25,8 @@ const (
)

type EventHandler func(evt *event.Event)
type OTKHandler func(otk *mautrix.OTKCount)
type DeviceListHandler func(otk *mautrix.DeviceLists, since string)

type EventProcessor struct {
ExecMode ExecMode
Expand All @@ -32,6 +35,9 @@ type EventProcessor struct {
log log.Logger
stop chan struct{}
handlers map[event.Type][]EventHandler

otkHandlers []OTKHandler
deviceListHandlers []DeviceListHandler
}

func NewEventProcessor(as *AppService) *EventProcessor {
Expand All @@ -41,6 +47,9 @@ func NewEventProcessor(as *AppService) *EventProcessor {
log: as.Log.Sub("Events"),
stop: make(chan struct{}, 1),
handlers: make(map[event.Type][]EventHandler),

otkHandlers: make([]OTKHandler, 0),
deviceListHandlers: make([]DeviceListHandler, 0),
}
}

Expand All @@ -54,16 +63,48 @@ func (ep *EventProcessor) On(evtType event.Type, handler EventHandler) {
ep.handlers[evtType] = handlers
}

func (ep *EventProcessor) OnOTK(handler OTKHandler) {
ep.otkHandlers = append(ep.otkHandlers, handler)
}

func (ep *EventProcessor) OnDeviceList(handler DeviceListHandler) {
ep.deviceListHandlers = append(ep.deviceListHandlers, handler)
}

func (ep *EventProcessor) recoverFunc(data interface{}) {
if err := recover(); err != nil {
d, _ := json.Marshal(data)
ep.log.Errorfln("Panic in Matrix event handler: %v (event content: %s):\n%s", err, string(d), string(debug.Stack()))
}
}

func (ep *EventProcessor) callHandler(handler EventHandler, evt *event.Event) {
defer func() {
if err := recover(); err != nil {
d, _ := json.Marshal(evt)
ep.log.Errorfln("Panic in Matrix event handler: %v (event content: %s):\n%s", err, string(d), string(debug.Stack()))
}
}()
defer ep.recoverFunc(evt)
handler(evt)
}

func (ep *EventProcessor) callOTKHandler(handler OTKHandler, otk *mautrix.OTKCount) {
defer ep.recoverFunc(otk)
handler(otk)
}

func (ep *EventProcessor) callDeviceListHandler(handler DeviceListHandler, dl *mautrix.DeviceLists) {
defer ep.recoverFunc(dl)
handler(dl, "")
}

func (ep *EventProcessor) DispatchOTK(otk *mautrix.OTKCount) {
for _, handler := range ep.otkHandlers {
go ep.callOTKHandler(handler, otk)
}
}

func (ep *EventProcessor) DispatchDeviceList(dl *mautrix.DeviceLists) {
for _, handler := range ep.deviceListHandlers {
go ep.callDeviceListHandler(handler, dl)
}
}

func (ep *EventProcessor) Dispatch(evt *event.Event) {
handlers, ok := ep.handlers[evt.Type]
if !ok {
Expand Down Expand Up @@ -92,6 +133,10 @@ func (ep *EventProcessor) Start() {
select {
case evt := <-ep.as.Events:
ep.Dispatch(evt)
case otk := <-ep.as.OTKCounts:
ep.DispatchOTK(otk)
case dl := <-ep.as.DeviceLists:
ep.DispatchDeviceList(dl)
case <-ep.stop:
return
}
Expand Down
56 changes: 44 additions & 12 deletions appservice/http.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
Expand All @@ -17,11 +17,12 @@ import (

"github.com/gorilla/mux"

"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)

// Listen starts the HTTP server that listens for calls from the Matrix homeserver.
// Start starts the HTTP server that listens for calls from the Matrix homeserver.
func (as *AppService) Start() {
as.Router.HandleFunc("/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
as.Router.HandleFunc("/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
Expand Down Expand Up @@ -119,8 +120,8 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
return
}

eventList := EventList{}
err = json.Unmarshal(body, &eventList)
var txn Transaction
err = json.Unmarshal(body, &txn)
if err != nil {
as.Log.Warnfln("Failed to parse JSON of transaction %s: %v", txnID, err)
Error{
Expand All @@ -130,22 +131,53 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
}.Write(w)
} else {
if as.Registration.EphemeralEvents {
if eventList.EphemeralEvents != nil {
as.handleEvents(eventList.EphemeralEvents, event.EphemeralEventType)
} else if eventList.SoruEphemeralEvents != nil {
as.handleEvents(eventList.SoruEphemeralEvents, event.EphemeralEventType)
if txn.EphemeralEvents != nil {
as.handleEvents(txn.EphemeralEvents, event.EphemeralEventType)
} else if txn.MSC2409EphemeralEvents != nil {
as.handleEvents(txn.MSC2409EphemeralEvents, event.EphemeralEventType)
}
}
as.handleEvents(eventList.Events, event.UnknownEventType)
as.handleEvents(txn.Events, event.UnknownEventType)
if txn.DeviceLists != nil {
as.handleDeviceLists(txn.DeviceLists)
} else if txn.MSC3202DeviceLists != nil {
as.handleDeviceLists(txn.MSC3202DeviceLists)
}
if txn.DeviceOTKCount != nil {
as.handleOTKCounts(txn.DeviceOTKCount)
} else if txn.MSC3202DeviceOTKCount != nil {
as.handleOTKCounts(txn.MSC3202DeviceOTKCount)
}
WriteBlankOK(w)
}
as.lastProcessedTransaction = txnID
}

func (as *AppService) handleEvents(evts []*event.Event, typeClass event.TypeClass) {
func (as *AppService) handleOTKCounts(otks map[id.UserID]mautrix.OTKCount) {
for userID, otkCounts := range otks {
otkCounts.UserID = userID
select {
case as.OTKCounts <- &otkCounts:
default:
as.Log.Warnfln("Dropped OTK count update for %s because channel is full", userID)
}
}
}

func (as *AppService) handleDeviceLists(dl *mautrix.DeviceLists) {
select {
case as.DeviceLists <- dl:
default:
as.Log.Warnln("Dropped device list update because channel is full")
}
}

func (as *AppService) handleEvents(evts []*event.Event, defaultTypeClass event.TypeClass) {
for _, evt := range evts {
if typeClass != event.UnknownEventType {
evt.Type.Class = typeClass
if len(evt.ToUserID) > 0 {
evt.Type.Class = event.ToDeviceEventType
} else if defaultTypeClass != event.UnknownEventType {
evt.Type.Class = defaultTypeClass
} else if evt.StateKey != nil {
evt.Type.Class = event.StateEventType
} else {
Expand Down
19 changes: 13 additions & 6 deletions appservice/protocol.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
Expand All @@ -10,14 +10,21 @@ import (
"encoding/json"
"net/http"

"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)

// EventList contains a list of events.
type EventList struct {
Events []*event.Event `json:"events"`
EphemeralEvents []*event.Event `json:"ephemeral"`
SoruEphemeralEvents []*event.Event `json:"de.sorunome.msc2409.ephemeral"`
// Transaction contains a list of events.
type Transaction struct {
Events []*event.Event `json:"events"`
EphemeralEvents []*event.Event `json:"ephemeral,omitempty"`
DeviceLists *mautrix.DeviceLists `json:"device_lists,omitempty"`
DeviceOTKCount map[id.UserID]mautrix.OTKCount `json:"device_one_time_keys_count,omitempty"`

MSC2409EphemeralEvents []*event.Event `json:"de.sorunome.msc2409.ephemeral,omitempty"`
MSC3202DeviceLists *mautrix.DeviceLists `json:"org.matrix.msc3202.device_lists,omitempty"`
MSC3202DeviceOTKCount map[id.UserID]mautrix.OTKCount `json:"org.matrix.msc3202.device_one_time_keys_count,omitempty"`
}

// EventListener is a function that receives events.
Expand Down
11 changes: 3 additions & 8 deletions appservice/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ import (
"maunium.net/go/mautrix/event"
)

type ErrorResponse struct {
ErrorCode ErrorCode `json:"errcode"`
Error string `json:"error"`
}

type WebsocketCommand struct {
ReqID int `json:"id,omitempty"`
Command string `json:"command"`
Expand All @@ -34,7 +29,7 @@ type WebsocketCommand struct {
type WebsocketTransaction struct {
Status string `json:"status"`
TxnID string `json:"txn_id"`
EventList
Transaction
}

type WebsocketMessage struct {
Expand Down Expand Up @@ -150,12 +145,12 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
"User-Agent": []string{as.BotClient().UserAgent},
})
if resp != nil && resp.StatusCode >= 400 {
var errResp ErrorResponse
var errResp Error
err = json.NewDecoder(resp.Body).Decode(&errResp)
if err != nil {
return fmt.Errorf("websocket request returned HTTP %d with non-JSON body", resp.StatusCode)
} else {
return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Error)
return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Message)
}
} else if err != nil {
return fmt.Errorf("failed to open websocket: %w", err)
Expand Down
54 changes: 42 additions & 12 deletions crypto/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"sync"
"time"

"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/crypto/ssss"
"maunium.net/go/mautrix/id"

Expand Down Expand Up @@ -153,16 +154,47 @@ func (mach *OlmMachine) OwnIdentity() *DeviceIdentity {
}
}

func (mach *OlmMachine) AddAppserviceListener(ep *appservice.EventProcessor, az *appservice.AppService) {
// ToDeviceForwardedRoomKey and ToDeviceRoomKey should only be present inside encrypted to-device events
ep.On(event.ToDeviceEncrypted, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceRoomKeyRequest, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceRoomKeyWithheld, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceOrgMatrixRoomKeyWithheld, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceVerificationRequest, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceVerificationStart, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceVerificationAccept, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceVerificationKey, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceVerificationMAC, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceVerificationCancel, mach.HandleToDeviceEvent)
ep.OnOTK(mach.HandleOTKCounts)
ep.OnDeviceList(mach.HandleDeviceLists)
}

func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) {
if len(dl.Changed) > 0 {
mach.Log.Trace("Device list changes in /sync: %v", dl.Changed)
mach.fetchKeys(dl.Changed, since, false)
}
}

func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2
if otkCount.SignedCurve25519 < int(minCount) {
mach.Log.Debug("Sync response said we have %d signed curve25519 keys left, sharing new ones...", otkCount.SignedCurve25519)
err := mach.ShareKeys(otkCount.SignedCurve25519)
if err != nil {
mach.Log.Error("Failed to share keys: %v", err)
}
}
}

// ProcessSyncResponse processes a single /sync response.
//
// This can be easily registered into a mautrix client using .OnSync():
//
// client.Syncer.(*mautrix.DefaultSyncer).OnSync(c.crypto.ProcessSyncResponse)
func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) bool {
if len(resp.DeviceLists.Changed) > 0 {
mach.Log.Trace("Device list changes in /sync: %v", resp.DeviceLists.Changed)
mach.fetchKeys(resp.DeviceLists.Changed, since, false)
}
mach.HandleDeviceLists(&resp.DeviceLists, since)

for _, evt := range resp.ToDevice.Events {
evt.Type.Class = event.ToDeviceEventType
Expand All @@ -174,14 +206,7 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
mach.HandleToDeviceEvent(evt)
}

min := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2
if resp.DeviceOneTimeKeysCount.SignedCurve25519 < int(min) {
mach.Log.Debug("Sync response said we have %d signed curve25519 keys left, sharing new ones...", resp.DeviceOneTimeKeysCount.SignedCurve25519)
err := mach.ShareKeys(resp.DeviceOneTimeKeysCount.SignedCurve25519)
if err != nil {
mach.Log.Error("Failed to share keys: %v", err)
}
}
mach.HandleOTKCounts(&resp.DeviceOTKCount)
return true
}

Expand Down Expand Up @@ -222,6 +247,11 @@ func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) {
// HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you
// don't need to add any custom handlers if you use that method.
func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) {
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
mach.Log.Debug("Dropping to-device event targeted to %s/%s (not us)", evt.ToUserID, evt.ToDeviceID)
return
}
switch content := evt.Content.Parsed.(type) {
case *event.EncryptedEventContent:
mach.Log.Debug("Handling encrypted to-device event from %s/%s", evt.Sender, content.SenderKey)
Expand Down
Loading

0 comments on commit 72aa965

Please sign in to comment.