From 3350232cf31acc4e624571df7ef0c3ba28fc7d60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20Junior?= Date: Wed, 7 Jul 2021 11:32:53 -0400 Subject: [PATCH] refactor(dot/rpc/subscription): refactor websocket `HandleComm` (#1673) --- dot/rpc/http.go | 4 + dot/rpc/subscription/listeners.go | 161 +++++++---- dot/rpc/subscription/subscription.go | 80 ++++++ dot/rpc/subscription/websocket.go | 353 ++++++++++++++----------- dot/rpc/subscription/websocket_test.go | 49 ++-- 5 files changed, 426 insertions(+), 221 deletions(-) create mode 100644 dot/rpc/subscription/subscription.go diff --git a/dot/rpc/http.go b/dot/rpc/http.go index 76038e3847..28ab690ab4 100644 --- a/dot/rpc/http.go +++ b/dot/rpc/http.go @@ -21,6 +21,7 @@ import ( "net" "net/http" "os" + "time" "github.com/ChainSafe/gossamer/dot/rpc/modules" "github.com/ChainSafe/gossamer/dot/rpc/subscription" @@ -242,6 +243,9 @@ func NewWSConn(conn *websocket.Conn, cfg *HTTPServerConfig) *subscription.WSConn CoreAPI: cfg.CoreAPI, TxStateAPI: cfg.TransactionQueueAPI, RPCHost: fmt.Sprintf("http://%s:%d/", cfg.Host, cfg.RPCPort), + HTTP: &http.Client{ + Timeout: time.Second * 30, + }, } return c } diff --git a/dot/rpc/subscription/listeners.go b/dot/rpc/subscription/listeners.go index a6b5d97abc..ba9e4ca7b6 100644 --- a/dot/rpc/subscription/listeners.go +++ b/dot/rpc/subscription/listeners.go @@ -16,6 +16,7 @@ package subscription import ( + "context" "fmt" "reflect" @@ -28,6 +29,7 @@ import ( // Listener interface for functions that define Listener related functions type Listener interface { Listen() + Stop() } // WSConnAPI interface defining methors a WSConn should have @@ -85,59 +87,98 @@ func (s *StorageObserver) GetFilter() map[string][]byte { // Listen to satisfy Listener interface (but is no longer used by StorageObserver) func (s *StorageObserver) Listen() {} +// Stop to satisfy Listener interface (but is no longer used by StorageObserver) +func (s *StorageObserver) Stop() {} + // BlockListener to handle listening for blocks importedChan type BlockListener struct { Channel chan *types.Block wsconn WSConnAPI ChanID byte subID uint + + ctx context.Context + cancel context.CancelFunc } // Listen implementation of Listen interface to listen for importedChan changes func (l *BlockListener) Listen() { - for block := range l.Channel { - if block == nil { - continue - } - head, err := modules.HeaderToJSON(*block.Header) - if err != nil { - logger.Error("failed to convert header to JSON", "error", err) - } + l.ctx, l.cancel = context.WithCancel(context.Background()) + go func() { + for { + select { + case <-l.ctx.Done(): + return + case block, ok := <-l.Channel: + if !ok { + return + } - res := newSubcriptionBaseResponseJSON() - res.Method = "chain_newHead" - res.Params.Result = head - res.Params.SubscriptionID = l.subID - l.wsconn.safeSend(res) - } + if block == nil { + continue + } + head, err := modules.HeaderToJSON(*block.Header) + if err != nil { + logger.Error("failed to convert header to JSON", "error", err) + } + + res := newSubcriptionBaseResponseJSON() + res.Method = "chain_newHead" + res.Params.Result = head + res.Params.SubscriptionID = l.subID + l.wsconn.safeSend(res) + } + } + }() } +// Stop to cancel the running goroutines to this listener +func (l *BlockListener) Stop() { l.cancel() } + // BlockFinalizedListener to handle listening for finalised blocks type BlockFinalizedListener struct { channel chan *types.FinalisationInfo wsconn WSConnAPI chanID byte subID uint + ctx context.Context + cancel context.CancelFunc } // Listen implementation of Listen interface to listen for importedChan changes func (l *BlockFinalizedListener) Listen() { - for info := range l.channel { - if info == nil || info.Header == nil { - continue - } - head, err := modules.HeaderToJSON(*info.Header) - if err != nil { - logger.Error("failed to convert header to JSON", "error", err) + l.ctx, l.cancel = context.WithCancel(context.Background()) + + go func() { + for { + select { + case <-l.ctx.Done(): + return + case info, ok := <-l.channel: + if !ok { + return + } + + if info == nil || info.Header == nil { + continue + } + head, err := modules.HeaderToJSON(*info.Header) + if err != nil { + logger.Error("failed to convert header to JSON", "error", err) + } + res := newSubcriptionBaseResponseJSON() + res.Method = "chain_finalizedHead" + res.Params.Result = head + res.Params.SubscriptionID = l.subID + l.wsconn.safeSend(res) + } } - res := newSubcriptionBaseResponseJSON() - res.Method = "chain_finalizedHead" - res.Params.Result = head - res.Params.SubscriptionID = l.subID - l.wsconn.safeSend(res) - } + }() } +// Stop to cancel the running goroutines to this listener +func (l *BlockFinalizedListener) Stop() { l.cancel() } + // ExtrinsicSubmitListener to handle listening for extrinsic events type ExtrinsicSubmitListener struct { wsconn WSConnAPI @@ -149,6 +190,9 @@ type ExtrinsicSubmitListener struct { importedHash common.Hash finalisedChan chan *types.FinalisationInfo finalisedChanID byte + + ctx context.Context + cancel context.CancelFunc } // AuthorExtrinsicUpdates method name @@ -156,39 +200,62 @@ const AuthorExtrinsicUpdates = "author_extrinsicUpdate" // Listen implementation of Listen interface to listen for importedChan changes func (l *ExtrinsicSubmitListener) Listen() { + l.ctx, l.cancel = context.WithCancel(context.Background()) + // listen for imported blocks with extrinsic go func() { - for block := range l.importedChan { - if block == nil { - continue - } - bodyHasExtrinsic, err := block.Body.HasExtrinsic(l.extrinsic) - if err != nil { - fmt.Printf("error %v\n", err) - } + for { + select { + case <-l.ctx.Done(): + return + case block, ok := <-l.importedChan: + if !ok { + return + } + + if block == nil { + continue + } + bodyHasExtrinsic, err := block.Body.HasExtrinsic(l.extrinsic) + if err != nil { + fmt.Printf("error %v\n", err) + } - if bodyHasExtrinsic { - resM := make(map[string]interface{}) - resM["inBlock"] = block.Header.Hash().String() + if bodyHasExtrinsic { + resM := make(map[string]interface{}) + resM["inBlock"] = block.Header.Hash().String() - l.importedHash = block.Header.Hash() - l.wsconn.safeSend(newSubscriptionResponse(AuthorExtrinsicUpdates, l.subID, resM)) + l.importedHash = block.Header.Hash() + l.wsconn.safeSend(newSubscriptionResponse(AuthorExtrinsicUpdates, l.subID, resM)) + } } } }() // listen for finalised headers go func() { - for info := range l.finalisedChan { - if reflect.DeepEqual(l.importedHash, info.Header.Hash()) { - resM := make(map[string]interface{}) - resM["finalised"] = info.Header.Hash().String() - l.wsconn.safeSend(newSubscriptionResponse(AuthorExtrinsicUpdates, l.subID, resM)) + for { + select { + case <-l.ctx.Done(): + return + case info, ok := <-l.finalisedChan: + if !ok { + return + } + + if reflect.DeepEqual(l.importedHash, info.Header.Hash()) { + resM := make(map[string]interface{}) + resM["finalised"] = info.Header.Hash().String() + l.wsconn.safeSend(newSubscriptionResponse(AuthorExtrinsicUpdates, l.subID, resM)) + } } } }() } +// Stop to cancel the running goroutines to this listener +func (l *ExtrinsicSubmitListener) Stop() { l.cancel() } + // RuntimeVersionListener to handle listening for Runtime Version type RuntimeVersionListener struct { wsconn *WSConn @@ -215,3 +282,7 @@ func (l *RuntimeVersionListener) Listen() { l.wsconn.safeSend(newSubscriptionResponse("state_runtimeVersion", l.subID, ver)) } + +// Stop to runtimeVersionListener not implemented yet because the listener +// does not need to be stoped +func (l *RuntimeVersionListener) Stop() {} diff --git a/dot/rpc/subscription/subscription.go b/dot/rpc/subscription/subscription.go new file mode 100644 index 0000000000..93413fec65 --- /dev/null +++ b/dot/rpc/subscription/subscription.go @@ -0,0 +1,80 @@ +package subscription + +import ( + "errors" + "fmt" + "strconv" +) + +var errUknownParamSubscribeID = errors.New("invalid params format type") +var errCannotParseID = errors.New("could not parse param id") +var errCannotFindListener = errors.New("could not find listener") +var errCannotFindUnsubsriber = errors.New("could not find unsubsriber function") + +type unsubListener func(reqid float64, l Listener, params interface{}) +type setupListener func(reqid float64, params interface{}) (Listener, error) + +func (c *WSConn) getSetupListener(method string) setupListener { + switch method { + case "chain_subscribeNewHeads", "chain_subscribeNewHead": + return c.initBlockListener + case "state_subscribeStorage": + return c.initStorageChangeListener + case "chain_subscribeFinalizedHeads": + return c.initBlockFinalizedListener + case "state_subscribeRuntimeVersion": + return c.initRuntimeVersionListener + default: + return nil + } +} + +func (c *WSConn) getUnsubListener(method string, params interface{}) (unsubListener, Listener, error) { + subscribeID, err := parseSubscribeID(params) + if err != nil { + return nil, nil, err + } + + listener, ok := c.Subscriptions[subscribeID] + if !ok { + return nil, nil, fmt.Errorf("subscriber id %v: %w", subscribeID, errCannotFindListener) + } + + var unsub unsubListener + + switch method { + case "state_unsubscribeStorage": + unsub = c.unsubscribeStorageListener + default: + return nil, nil, errCannotFindUnsubsriber + } + + return unsub, listener, nil +} + +func parseSubscribeID(p interface{}) (uint, error) { + switch v := p.(type) { + case []interface{}: + if len(v) == 0 { + return 0, errUknownParamSubscribeID + } + default: + return 0, errUknownParamSubscribeID + } + + var id uint + switch v := p.([]interface{})[0].(type) { + case float64: + id = uint(v) + case string: + i, err := strconv.ParseUint(v, 10, 32) + if err != nil { + return 0, errCannotParseID + } + id = uint(i) + default: + return 0, errUknownParamSubscribeID + } + + return id, nil +} diff --git a/dot/rpc/subscription/websocket.go b/dot/rpc/subscription/websocket.go index ab1b367449..e36bed8795 100644 --- a/dot/rpc/subscription/websocket.go +++ b/dot/rpc/subscription/websocket.go @@ -19,11 +19,11 @@ package subscription import ( "bytes" "encoding/json" + "errors" "fmt" "io/ioutil" "math/big" "net/http" - "strconv" "strings" "sync" @@ -35,6 +35,12 @@ import ( "github.com/gorilla/websocket" ) +type httpclient interface { + Do(*http.Request) (*http.Response, error) +} + +var errCannotReadFromWebsocket = errors.New("cannot read message from websocket") +var errCannotUnmarshalMessage = errors.New("cannot unmarshal webasocket message data") var logger = log.New("pkg", "rpc/subscription") // WSConn struct to hold WebSocket Connection references @@ -51,141 +57,130 @@ type WSConn struct { CoreAPI modules.CoreAPI TxStateAPI modules.TransactionStateAPI RPCHost string + + HTTP httpclient +} + +// readWebsocketMessage will read and parse the message data to a string->interface{} data +func (c *WSConn) readWebsocketMessage() ([]byte, map[string]interface{}, error) { + _, mbytes, err := c.Wsconn.ReadMessage() + if err != nil { + logger.Debug("websocket failed to read message", "error", err) + return nil, nil, errCannotReadFromWebsocket + } + + logger.Trace("websocket received", "message", mbytes) + + // determine if request is for subscribe method type + var msg map[string]interface{} + err = json.Unmarshal(mbytes, &msg) + + if err != nil { + logger.Debug("websocket failed to unmarshal request message", "error", err) + return nil, nil, errCannotUnmarshalMessage + } + + return mbytes, msg, nil } //HandleComm handles messages received on websocket connections func (c *WSConn) HandleComm() { for { - _, mbytes, err := c.Wsconn.ReadMessage() - if err != nil { - logger.Warn("websocket failed to read message", "error", err) + mbytes, msg, err := c.readWebsocketMessage() + if errors.Is(err, errCannotReadFromWebsocket) { return } - logger.Trace("websocket received", "message", mbytes) - // determine if request is for subscribe method type - var msg map[string]interface{} - err = json.Unmarshal(mbytes, &msg) - if err != nil { - logger.Warn("websocket failed to unmarshal request message", "error", err) - c.safeSendError(0, big.NewInt(-32600), "Invalid request") + if errors.Is(err, errCannotUnmarshalMessage) { + c.safeSendError(0, big.NewInt(InvalidRequestCode), InvalidRequestMessage) continue } - method := msg["method"] params := msg["params"] + reqid := msg["id"].(float64) + method := msg["method"].(string) + logger.Debug("ws method called", "method", method, "params", params) - // if method contains subscribe, then register subscription - if strings.Contains(fmt.Sprintf("%s", method), "subscribe") { - reqid := msg["id"].(float64) - switch method { - case "chain_subscribeNewHeads", "chain_subscribeNewHead": - bl, err1 := c.initBlockListener(reqid) - if err1 != nil { - logger.Warn("failed to create block listener", "error", err) - continue - } - c.startListener(bl) - case "state_subscribeStorage": - _, err2 := c.initStorageChangeListener(reqid, params) - if err2 != nil { - logger.Warn("failed to create state change listener", "error", err2) - continue - } + if strings.Contains(method, "_subscribe") { + setup := c.getSetupListener(method) + + listener, err := setup(reqid, params) //nolint + if err != nil { + logger.Warn("failed to create listener", "method", method, "error", err) + continue + } + + listener.Listen() + continue + } - case "chain_subscribeFinalizedHeads": - bfl, err3 := c.initBlockFinalizedListener(reqid) - if err3 != nil { - logger.Warn("failed to create block finalised", "error", err3) + if strings.Contains(method, "_unsubscribe") { + unsub, listener, err := c.getUnsubListener(method, params) //nolint + + if err != nil { + logger.Warn("failed to get unsubscriber", "method", method, "error", err) + + if errors.Is(err, errUknownParamSubscribeID) || errors.Is(err, errCannotFindUnsubsriber) { + c.safeSendError(reqid, big.NewInt(InvalidRequestCode), InvalidRequestMessage) continue } - c.startListener(bfl) - case "state_subscribeRuntimeVersion": - rvl, err4 := c.initRuntimeVersionListener(reqid) - if err4 != nil { - logger.Warn("failed to create runtime version listener", "error", err4) + + if errors.Is(err, errCannotParseID) || errors.Is(err, errCannotFindListener) { + c.safeSend(newBooleanResponseJSON(false, reqid)) continue } - c.startListener(rvl) - case "state_unsubscribeStorage": - c.unsubscribeStorageListener(reqid, params) - } + + unsub(reqid, listener, params) + listener.Stop() continue } - if strings.Contains(fmt.Sprintf("%s", method), "submitAndWatchExtrinsic") { - reqid := msg["id"].(float64) - params := msg["params"] - el, e := c.initExtrinsicWatch(reqid, params) - if e != nil { - c.safeSendError(reqid, nil, e.Error()) - } else { - c.startListener(el) + if strings.Contains(method, "submitAndWatchExtrinsic") { + listener, err := c.initExtrinsicWatch(reqid, params) //nolint + if err != nil { + logger.Warn("failed to create listener", "method", method, "error", err) + c.safeSendError(reqid, nil, err.Error()) + continue } + + listener.Listen() continue } // handle non-subscribe calls - client := &http.Client{} - buf := &bytes.Buffer{} - _, err = buf.Write(mbytes) - if err != nil { - logger.Warn("failed to write message to buffer", "error", err) - return - } - - req, err := http.NewRequest("POST", c.RPCHost, buf) + request, err := c.prepareRequest(mbytes) if err != nil { - logger.Warn("failed request to rpc service", "error", err) + logger.Warn("failed while preparing the request", "error", err) return } - req.Header.Set("Content-Type", "application/json;") - - res, err := client.Do(req) + var wsresponse interface{} + err = c.executeRequest(request, &wsresponse) if err != nil { - logger.Warn("websocket error calling rpc", "error", err) + logger.Warn("problems while executing the request", "error", err) return } - body, err := ioutil.ReadAll(res.Body) - if err != nil { - logger.Warn("error reading response body", "error", err) - return - } - - err = res.Body.Close() - if err != nil { - logger.Warn("error closing response body", "error", err) - return - } - var wsSend interface{} - err = json.Unmarshal(body, &wsSend) - if err != nil { - logger.Warn("error unmarshal rpc response", "error", err) - return - } - - c.safeSend(wsSend) + c.safeSend(wsresponse) } } -func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (uint, error) { +func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (Listener, error) { if c.StorageAPI == nil { c.safeSendError(reqID, nil, "error StorageAPI not set") - return 0, fmt.Errorf("error StorageAPI not set") + return nil, fmt.Errorf("error StorageAPI not set") } - myObs := &StorageObserver{ + stgobs := &StorageObserver{ filter: make(map[string][]byte), wsconn: c, } pA, ok := params.([]interface{}) if !ok { - return 0, fmt.Errorf("unknown parameter type") + return nil, fmt.Errorf("unknown parameter type") } for _, param := range pA { switch p := param.(type) { @@ -193,59 +188,34 @@ func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (u for _, pp := range param.([]interface{}) { data, ok := pp.(string) if !ok { - return 0, fmt.Errorf("unknown parameter type") + return nil, fmt.Errorf("unknown parameter type") } - myObs.filter[data] = []byte{} + stgobs.filter[data] = []byte{} } case string: - myObs.filter[p] = []byte{} + stgobs.filter[p] = []byte{} default: - return 0, fmt.Errorf("unknown parameter type") + return nil, fmt.Errorf("unknown parameter type") } } - c.qtyListeners++ - myObs.id = c.qtyListeners + c.mu.Lock() - c.StorageAPI.RegisterStorageObserver(myObs) + c.qtyListeners++ + stgobs.id = c.qtyListeners + c.Subscriptions[stgobs.id] = stgobs - c.Subscriptions[myObs.id] = myObs + c.mu.Unlock() - initRes := NewSubscriptionResponseJSON(myObs.id, reqID) + c.StorageAPI.RegisterStorageObserver(stgobs) + initRes := NewSubscriptionResponseJSON(stgobs.id, reqID) c.safeSend(initRes) - return myObs.id, nil + return stgobs, nil } -func (c *WSConn) unsubscribeStorageListener(reqID float64, params interface{}) { - switch v := params.(type) { - case []interface{}: - if len(v) == 0 { - c.safeSendError(reqID, big.NewInt(InvalidRequestCode), InvalidRequestMessage) - return - } - default: - c.safeSendError(reqID, big.NewInt(InvalidRequestCode), InvalidRequestMessage) - return - } - - var id uint - switch v := params.([]interface{})[0].(type) { - case float64: - id = uint(v) - case string: - i, err := strconv.ParseUint(v, 10, 32) - if err != nil { - c.safeSend(newBooleanResponseJSON(false, reqID)) - return - } - id = uint(i) - default: - c.safeSendError(reqID, big.NewInt(InvalidRequestCode), InvalidRequestMessage) - return - } - - observer, ok := c.Subscriptions[id].(state.Observer) +func (c *WSConn) unsubscribeStorageListener(reqID float64, l Listener, _ interface{}) { + observer, ok := l.(state.Observer) if !ok { initRes := newBooleanResponseJSON(false, reqID) c.safeSend(initRes) @@ -256,7 +226,7 @@ func (c *WSConn) unsubscribeStorageListener(reqID float64, params interface{}) { c.safeSend(newBooleanResponseJSON(true, reqID)) } -func (c *WSConn) initBlockListener(reqID float64) (uint, error) { +func (c *WSConn) initBlockListener(reqID float64, _ interface{}) (Listener, error) { bl := &BlockListener{ Channel: make(chan *types.Block), wsconn: c, @@ -264,24 +234,31 @@ func (c *WSConn) initBlockListener(reqID float64) (uint, error) { if c.BlockAPI == nil { c.safeSendError(reqID, nil, "error BlockAPI not set") - return 0, fmt.Errorf("error BlockAPI not set") + return nil, fmt.Errorf("error BlockAPI not set") } - chanID, err := c.BlockAPI.RegisterImportedChannel(bl.Channel) + + var err error + bl.ChanID, err = c.BlockAPI.RegisterImportedChannel(bl.Channel) + if err != nil { - return 0, err + return nil, err } - bl.ChanID = chanID + + c.mu.Lock() + c.qtyListeners++ bl.subID = c.qtyListeners c.Subscriptions[bl.subID] = bl - c.BlockSubChannels[bl.subID] = chanID - initRes := NewSubscriptionResponseJSON(bl.subID, reqID) - c.safeSend(initRes) + c.BlockSubChannels[bl.subID] = bl.ChanID + + c.mu.Unlock() + + c.safeSend(NewSubscriptionResponseJSON(bl.subID, reqID)) - return bl.subID, nil + return bl, nil } -func (c *WSConn) initBlockFinalizedListener(reqID float64) (uint, error) { +func (c *WSConn) initBlockFinalizedListener(reqID float64, _ interface{}) (Listener, error) { bfl := &BlockFinalizedListener{ channel: make(chan *types.FinalisationInfo), wsconn: c, @@ -289,28 +266,35 @@ func (c *WSConn) initBlockFinalizedListener(reqID float64) (uint, error) { if c.BlockAPI == nil { c.safeSendError(reqID, nil, "error BlockAPI not set") - return 0, fmt.Errorf("error BlockAPI not set") + return nil, fmt.Errorf("error BlockAPI not set") } - chanID, err := c.BlockAPI.RegisterFinalizedChannel(bfl.channel) + + var err error + bfl.chanID, err = c.BlockAPI.RegisterFinalizedChannel(bfl.channel) if err != nil { - return 0, err + return nil, err } - bfl.chanID = chanID + + c.mu.Lock() + c.qtyListeners++ bfl.subID = c.qtyListeners c.Subscriptions[bfl.subID] = bfl - c.BlockSubChannels[bfl.subID] = chanID + c.BlockSubChannels[bfl.subID] = bfl.chanID + + c.mu.Unlock() + initRes := NewSubscriptionResponseJSON(bfl.subID, reqID) c.safeSend(initRes) - return bfl.subID, nil + return bfl, nil } -func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (uint, error) { +func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (Listener, error) { pA := params.([]interface{}) extBytes, err := common.HexToBytes(pA[0].(string)) if err != nil { - return 0, err + return nil, err } // listen for built blocks @@ -322,26 +306,30 @@ func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (uint, er } if c.BlockAPI == nil { - return 0, fmt.Errorf("error BlockAPI not set") + return nil, fmt.Errorf("error BlockAPI not set") } esl.importedChanID, err = c.BlockAPI.RegisterImportedChannel(esl.importedChan) if err != nil { - return 0, err + return nil, err } esl.finalisedChanID, err = c.BlockAPI.RegisterFinalizedChannel(esl.finalisedChan) if err != nil { - return 0, err + return nil, err } + c.mu.Lock() + c.qtyListeners++ esl.subID = c.qtyListeners c.Subscriptions[esl.subID] = esl c.BlockSubChannels[esl.subID] = esl.importedChanID + c.mu.Unlock() + err = c.CoreAPI.HandleSubmittedExtrinsic(extBytes) if err != nil { - return 0, err + return nil, err } c.safeSend(NewSubscriptionResponseJSON(esl.subID, reqID)) @@ -350,24 +338,30 @@ func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (uint, er c.safeSend(newSubscriptionResponse(AuthorExtrinsicUpdates, esl.subID, "ready")) // todo (ed) determine which peer extrinsic has been broadcast to, and set status - return esl.subID, err + return esl, err } -func (c *WSConn) initRuntimeVersionListener(reqID float64) (uint, error) { +func (c *WSConn) initRuntimeVersionListener(reqID float64, _ interface{}) (Listener, error) { rvl := &RuntimeVersionListener{ wsconn: c, } + if c.CoreAPI == nil { c.safeSendError(reqID, nil, "error CoreAPI not set") - return 0, fmt.Errorf("error CoreAPI not set") + return nil, fmt.Errorf("error CoreAPI not set") } + + c.mu.Lock() + c.qtyListeners++ rvl.subID = c.qtyListeners c.Subscriptions[rvl.subID] = rvl - initRes := NewSubscriptionResponseJSON(rvl.subID, reqID) - c.safeSend(initRes) - return rvl.subID, nil + c.mu.Unlock() + + c.safeSend(NewSubscriptionResponseJSON(rvl.subID, reqID)) + + return rvl, nil } func (c *WSConn) safeSend(msg interface{}) { @@ -378,6 +372,7 @@ func (c *WSConn) safeSend(msg interface{}) { logger.Debug("error sending websocket message", "error", err) } } + func (c *WSConn) safeSendError(reqID float64, errorCode *big.Int, message string) { res := &ErrorResponseJSON{ Jsonrpc: "2.0", @@ -395,6 +390,52 @@ func (c *WSConn) safeSendError(reqID float64, errorCode *big.Int, message string } } +func (c *WSConn) prepareRequest(b []byte) (*http.Request, error) { + buff := &bytes.Buffer{} + if _, err := buff.Write(b); err != nil { + logger.Warn("failed to write message to buffer", "error", buff) + return nil, err + } + + req, err := http.NewRequest("POST", c.RPCHost, buff) + if err != nil { + logger.Warn("failed request to rpc service", "error", err) + return nil, err + } + + req.Header.Set("Content-Type", "application/json;") + return req, nil +} + +func (c *WSConn) executeRequest(r *http.Request, d interface{}) error { + res, err := c.HTTP.Do(r) + if err != nil { + logger.Warn("websocket error calling rpc", "error", err) + return err + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + logger.Warn("error reading response body", "error", err) + return err + } + + err = res.Body.Close() + if err != nil { + logger.Warn("error closing response body", "error", err) + return err + } + + err = json.Unmarshal(body, d) + + if err != nil { + logger.Warn("error unmarshal rpc response", "error", err) + return err + } + + return nil +} + // ErrorResponseJSON json for error responses type ErrorResponseJSON struct { Jsonrpc string `json:"jsonrpc"` @@ -407,7 +448,3 @@ type ErrorMessageJSON struct { Code *big.Int `json:"code"` Message string `json:"message"` } - -func (c *WSConn) startListener(lid uint) { - go c.Subscriptions[lid].Listen() -} diff --git a/dot/rpc/subscription/websocket_test.go b/dot/rpc/subscription/websocket_test.go index 5244a9d39f..cb2ab48906 100644 --- a/dot/rpc/subscription/websocket_test.go +++ b/dot/rpc/subscription/websocket_test.go @@ -15,6 +15,7 @@ import ( var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } + var wsconn = &WSConn{ Subscriptions: make(map[uint]Listener), BlockSubChannels: make(map[uint]byte), @@ -42,6 +43,7 @@ func TestMain(m *testing.M) { } }() time.Sleep(time.Millisecond * 100) + // Start all tests os.Exit(m.Run()) } @@ -55,8 +57,9 @@ func TestWSConn_HandleComm(t *testing.T) { // test storageChangeListener res, err := wsconn.initStorageChangeListener(1, nil) + require.Nil(t, res) + require.Len(t, wsconn.Subscriptions, 0) require.EqualError(t, err, "error StorageAPI not set") - require.Equal(t, uint(0), res) _, msg, err := c.ReadMessage() require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","error":{"code":null,"message":"error StorageAPI not set"},"id":1}`+"\n"), msg) @@ -64,19 +67,22 @@ func TestWSConn_HandleComm(t *testing.T) { wsconn.StorageAPI = modules.NewMockStorageAPI() res, err = wsconn.initStorageChangeListener(1, nil) + require.Nil(t, res) + require.Len(t, wsconn.Subscriptions, 0) require.EqualError(t, err, "unknown parameter type") - require.Equal(t, uint(0), res) res, err = wsconn.initStorageChangeListener(2, []interface{}{}) + require.NotNil(t, res) require.NoError(t, err) - require.Equal(t, uint(1), res) + require.Len(t, wsconn.Subscriptions, 1) _, msg, err = c.ReadMessage() require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","result":1,"id":2}`+"\n"), msg) res, err = wsconn.initStorageChangeListener(3, []interface{}{"0x26aa"}) + require.NotNil(t, res) require.NoError(t, err) - require.Equal(t, uint(2), res) + require.Len(t, wsconn.Subscriptions, 2) _, msg, err = c.ReadMessage() require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","result":2,"id":3}`+"\n"), msg) @@ -84,8 +90,9 @@ func TestWSConn_HandleComm(t *testing.T) { var testFilters = []interface{}{} var testFilter1 = []interface{}{"0x26aa", "0x26a1"} res, err = wsconn.initStorageChangeListener(4, append(testFilters, testFilter1)) + require.NotNil(t, res) require.NoError(t, err) - require.Equal(t, uint(3), res) + require.Len(t, wsconn.Subscriptions, 3) _, msg, err = c.ReadMessage() require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","result":3,"id":4}`+"\n"), msg) @@ -93,11 +100,14 @@ func TestWSConn_HandleComm(t *testing.T) { var testFilterWrongType = []interface{}{"0x26aa", 1} res, err = wsconn.initStorageChangeListener(5, append(testFilters, testFilterWrongType)) require.EqualError(t, err, "unknown parameter type") - require.Equal(t, uint(0), res) + require.Nil(t, res) + // keep subscriptions len == 3, no additions was made + require.Len(t, wsconn.Subscriptions, 3) res, err = wsconn.initStorageChangeListener(6, []interface{}{1}) require.EqualError(t, err, "unknown parameter type") - require.Equal(t, uint(0), res) + require.Nil(t, res) + require.Len(t, wsconn.Subscriptions, 3) c.WriteMessage(websocket.TextMessage, []byte(`{ "jsonrpc": "2.0", @@ -164,18 +174,19 @@ func TestWSConn_HandleComm(t *testing.T) { require.Equal(t, []byte(`{"jsonrpc":"2.0","result":true,"id":7}`+"\n"), msg) // test initBlockListener - res, err = wsconn.initBlockListener(1) + res, err = wsconn.initBlockListener(1, nil) require.EqualError(t, err, "error BlockAPI not set") - require.Equal(t, uint(0), res) + require.Nil(t, res) _, msg, err = c.ReadMessage() require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","error":{"code":null,"message":"error BlockAPI not set"},"id":1}`+"\n"), msg) wsconn.BlockAPI = modules.NewMockBlockAPI() - res, err = wsconn.initBlockListener(1) + res, err = wsconn.initBlockListener(1, nil) require.NoError(t, err) - require.Equal(t, uint(5), res) + require.NotNil(t, res) + require.Len(t, wsconn.Subscriptions, 5) _, msg, err = c.ReadMessage() require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","result":5,"id":1}`+"\n"), msg) @@ -193,18 +204,19 @@ func TestWSConn_HandleComm(t *testing.T) { // test initBlockFinalizedListener wsconn.BlockAPI = nil - res, err = wsconn.initBlockFinalizedListener(1) + res, err = wsconn.initBlockFinalizedListener(1, nil) require.EqualError(t, err, "error BlockAPI not set") - require.Equal(t, uint(0), res) + require.Nil(t, res) _, msg, err = c.ReadMessage() require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","error":{"code":null,"message":"error BlockAPI not set"},"id":1}`+"\n"), msg) wsconn.BlockAPI = modules.NewMockBlockAPI() - res, err = wsconn.initBlockFinalizedListener(1) + res, err = wsconn.initBlockFinalizedListener(1, nil) require.NoError(t, err) - require.Equal(t, uint(7), res) + require.NotNil(t, res) + require.Len(t, wsconn.Subscriptions, 7) _, msg, err = c.ReadMessage() require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","result":7,"id":1}`+"\n"), msg) @@ -214,15 +226,16 @@ func TestWSConn_HandleComm(t *testing.T) { wsconn.BlockAPI = nil res, err = wsconn.initExtrinsicWatch(0, []interface{}{"NotHex"}) require.EqualError(t, err, "could not byteify non 0x prefixed string") - require.Equal(t, uint(0), res) + require.Nil(t, res) res, err = wsconn.initExtrinsicWatch(0, []interface{}{"0x26aa"}) require.EqualError(t, err, "error BlockAPI not set") - require.Equal(t, uint(0), res) + require.Nil(t, res) wsconn.BlockAPI = modules.NewMockBlockAPI() res, err = wsconn.initExtrinsicWatch(0, []interface{}{"0x26aa"}) require.NoError(t, err) - require.Equal(t, uint(8), res) + require.NotNil(t, res) + require.Len(t, wsconn.Subscriptions, 8) }