diff --git a/dot/rpc/http.go b/dot/rpc/http.go index 82490dd561..2efce0136e 100644 --- a/dot/rpc/http.go +++ b/dot/rpc/http.go @@ -20,14 +20,13 @@ import ( "fmt" "net/http" "os" + "sync" "github.com/ChainSafe/gossamer/dot/rpc/modules" - "github.com/ChainSafe/gossamer/dot/types" + log "github.com/ChainSafe/log15" "github.com/gorilla/mux" "github.com/gorilla/rpc/v2" "github.com/gorilla/websocket" - - log "github.com/ChainSafe/log15" ) // HTTPServer gateway for RPC server @@ -35,8 +34,7 @@ type HTTPServer struct { logger log.Logger rpcServer *rpc.Server // Actual RPC call handler serverConfig *HTTPServerConfig - blockChan chan *types.Block - chanID byte // channel ID + wsConns []*WSConn } // HTTPServerConfig configures the HTTPServer @@ -56,18 +54,25 @@ type HTTPServerConfig struct { WSEnabled bool WSPort uint32 Modules []string - WSSubscriptions map[uint32]*WebSocketSubscription } -// WebSocketSubscription holds subscription details -type WebSocketSubscription struct { - WSConnection *websocket.Conn - SubscriptionType int +// WSConn struct to hold WebSocket Connection references +type WSConn struct { + wsconn *websocket.Conn + mu sync.Mutex + blockSubChannels map[int]byte + storageSubChannels map[int]byte + qtyListeners int + subscriptions map[int]Listener + storageAPI modules.StorageAPI + blockAPI modules.BlockAPI } +var logger log.Logger + // NewHTTPServer creates a new http server and registers an associated rpc server func NewHTTPServer(cfg *HTTPServerConfig) *HTTPServer { - logger := log.New("pkg", "rpc") + logger = log.New("pkg", "rpc") h := log.StreamHandler(os.Stdout, log.TerminalFormat()) logger.SetHandler(log.LvlFilterHandler(cfg.LogLvl, h)) @@ -77,10 +82,6 @@ func NewHTTPServer(cfg *HTTPServerConfig) *HTTPServer { serverConfig: cfg, } - if cfg.WSSubscriptions == nil { - cfg.WSSubscriptions = make(map[uint32]*WebSocketSubscription) - } - server.RegisterModules(cfg.Modules) return server } @@ -151,25 +152,30 @@ func (h *HTTPServer) Start() error { } }() - // init and start block received listener routine - if h.serverConfig.BlockAPI != nil { - var err error - h.blockChan = make(chan *types.Block) - h.chanID, err = h.serverConfig.BlockAPI.RegisterImportedChannel(h.blockChan) - if err != nil { - return err - } - go h.blockReceivedListener() - } - return nil } // Stop stops the server func (h *HTTPServer) Stop() error { if h.serverConfig.WSEnabled { - h.serverConfig.BlockAPI.UnregisterImportedChannel(h.chanID) - close(h.blockChan) + // close all channels and websocket connections + for _, conn := range h.wsConns { + for _, sub := range conn.subscriptions { + switch v := sub.(type) { + case *StorageChangeListener: + h.serverConfig.StorageAPI.UnregisterStorageChangeChannel(v.chanID) + close(v.channel) + case *BlockListener: + h.serverConfig.BlockAPI.UnregisterImportedChannel(v.chanID) + close(v.channel) + } + } + + err := conn.wsconn.Close() + if err != nil { + h.logger.Error("error closing websocket connection", "error", err) + } + } } return nil } diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go index a59efb70d9..0e8b0aa6e9 100644 --- a/dot/rpc/modules/api.go +++ b/dot/rpc/modules/api.go @@ -3,6 +3,7 @@ package modules import ( "math/big" + "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/crypto" @@ -14,6 +15,8 @@ import ( type StorageAPI interface { GetStorage(key []byte) ([]byte, error) Entries() map[string][]byte + RegisterStorageChangeChannel(ch chan<- *state.KeyValue) (byte, error) + UnregisterStorageChangeChannel(id byte) } // BlockAPI is the interface for the block state diff --git a/dot/rpc/modules/state.go b/dot/rpc/modules/state.go index 04e7659171..3d322f107b 100644 --- a/dot/rpc/modules/state.go +++ b/dot/rpc/modules/state.go @@ -285,9 +285,11 @@ func (sm *StateModule) SubscribeRuntimeVersion(r *http.Request, req *StateStorag return sm.GetRuntimeVersion(r, nil, res) } -// SubscribeStorage isn't implemented properly yet. -func (sm *StateModule) SubscribeStorage(r *http.Request, req *StateStorageQueryRangeRequest, res *StorageChangeSetResponse) { - // TODO implement change storage trie so that block hash parameter works (See issue #834) +// SubscribeStorage Storage subscription. If storage keys are specified, it creates a message for each block which +// changes the specified storage keys. If none are specified, then it creates a message for every block. +// This endpoint communicates over the Websocket protocol, but this func should remain here so it's added to rpc_methods list +func (sm *StateModule) SubscribeStorage(r *http.Request, req *StateStorageQueryRangeRequest, res *StorageChangeSetResponse) error { + return nil } func convertAPIs(in []*runtime.API_Item) []interface{} { diff --git a/dot/rpc/websocket.go b/dot/rpc/websocket.go index 2b2cefa17c..ff81da2a0e 100644 --- a/dot/rpc/websocket.go +++ b/dot/rpc/websocket.go @@ -25,44 +25,39 @@ import ( "strings" "github.com/ChainSafe/gossamer/dot/rpc/modules" - - "github.com/ethereum/go-ethereum/log" + "github.com/ChainSafe/gossamer/dot/state" + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" "github.com/gorilla/websocket" ) -// consts to represent subscription type -const ( - SUB_NEW_HEAD = iota - SUB_FINALIZED_HEAD - SUB_STORAGE -) - // SubscriptionBaseResponseJSON for base json response type SubscriptionBaseResponseJSON struct { Jsonrpc string `json:"jsonrpc"` Method string `json:"method"` Params interface{} `json:"params"` - Subscription uint32 `json:"subscription"` + Subscription int `json:"subscription"` } -func newSubcriptionBaseResponseJSON(sub uint32) SubscriptionBaseResponseJSON { +func newSubcriptionBaseResponseJSON(subID int) SubscriptionBaseResponseJSON { return SubscriptionBaseResponseJSON{ Jsonrpc: "2.0", - Subscription: sub, + Subscription: subID, } } // SubscriptionResponseJSON for json subscription responses type SubscriptionResponseJSON struct { Jsonrpc string `json:"jsonrpc"` - Result uint32 `json:"result"` + Result int `json:"result"` ID float64 `json:"id"` } -func newSubscriptionResponseJSON() SubscriptionResponseJSON { +func newSubscriptionResponseJSON(subID int, reqID float64) SubscriptionResponseJSON { return SubscriptionResponseJSON{ Jsonrpc: "2.0", - Result: 0, + Result: subID, + ID: reqID, } } @@ -70,7 +65,7 @@ func newSubscriptionResponseJSON() SubscriptionResponseJSON { type ErrorResponseJSON struct { Jsonrpc string `json:"jsonrpc"` Error *ErrorMessageJSON `json:"error"` - ID *big.Int `json:"id"` + ID float64 `json:"id"` } // ErrorMessageJSON json for error messages @@ -79,80 +74,120 @@ type ErrorMessageJSON struct { Message string `json:"message"` } +var rpcHost string + // ServeHTTP implemented to handle WebSocket connections func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { var upg = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { - return true + return true // todo determine how this should check orgigin }, } + ws, err := upg.Upgrade(w, r, nil) if err != nil { - log.Error("[rpc] websocket upgrade failed", "error", err) + h.logger.Error("websocket upgrade failed", "error", err) return } + // create wsConn + wsc := NewWSConn(ws, h.serverConfig) + h.wsConns = append(h.wsConns, wsc) - rpcHost := fmt.Sprintf("http://%s:%d/", h.serverConfig.Host, h.serverConfig.RPCPort) + go wsc.handleComm() +} + +// NewWSConn to create new WebSocket Connection struct +func NewWSConn(conn *websocket.Conn, cfg *HTTPServerConfig) *WSConn { + rpcHost = fmt.Sprintf("http://%s:%d/", cfg.Host, cfg.RPCPort) + c := &WSConn{ + wsconn: conn, + subscriptions: make(map[int]Listener), + blockSubChannels: make(map[int]byte), + storageSubChannels: make(map[int]byte), + storageAPI: cfg.StorageAPI, + blockAPI: cfg.BlockAPI, + } + return c +} + +func (c *WSConn) safeSend(msg interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + + return c.wsconn.WriteJSON(msg) +} +func (c *WSConn) safeSendError(reqID float64, errorCode *big.Int, message string) error { + res := &ErrorResponseJSON{ + Jsonrpc: "2.0", + Error: &ErrorMessageJSON{ + Code: errorCode, + Message: message, + }, + ID: reqID, + } + c.mu.Lock() + defer c.mu.Unlock() + + return c.wsconn.WriteJSON(res) +} + +func (c *WSConn) handleComm() { for { - _, mbytes, err := ws.ReadMessage() + _, mbytes, err := c.wsconn.ReadMessage() if err != nil { - log.Error("[rpc] websocket failed to read message", "error", err) + logger.Warn("websocket failed to read message", "error", err) return } - log.Trace("[rpc] websocket received", "message", fmt.Sprintf("%s", mbytes)) + logger.Debug("websocket received", "message", fmt.Sprintf("%s", mbytes)) // determine if request is for subscribe method type var msg map[string]interface{} err = json.Unmarshal(mbytes, &msg) if err != nil { - log.Error("[rpc] websocket failed to unmarshal request message", "error", err) - res := &ErrorResponseJSON{ - Jsonrpc: "2.0", - Error: &ErrorMessageJSON{ - Code: big.NewInt(-32600), - Message: "Invalid request", - }, - ID: nil, - } - err = ws.WriteJSON(res) + logger.Warn("websocket failed to unmarshal request message", "error", err) + err = c.safeSendError(0, big.NewInt(-32600), "Invalid request") if err != nil { - log.Error("[rpc] websocket failed write message", "error", err) + logger.Warn("websocket failed write message", "error", err) } continue } method := msg["method"] // if method contains subscribe, then register subscription if strings.Contains(fmt.Sprintf("%s", method), "subscribe") { - mid := msg["id"].(float64) - var subType int + reqid := msg["id"].(float64) + params := msg["params"] switch method { case "chain_subscribeNewHeads", "chain_subscribeNewHead": - subType = SUB_NEW_HEAD - case "chain_subscribeStorage": - subType = SUB_STORAGE + bl, err1 := c.initBlockListener(reqid) + if err1 != nil { + logger.Warn("failed to create block listener", "error", err) + continue + } + c.startListener(bl) + case "state_subscribeStorage": + scl, err2 := c.initStorageChangeListener(reqid, params) + if err2 != nil { + logger.Warn("failed to create state change listener", "error", err) + continue + } + c.startListener(scl) case "chain_subscribeFinalizedHeads": - subType = SUB_FINALIZED_HEAD - } - - var e1 error - _, e1 = h.registerSubscription(ws, mid, subType) - if e1 != nil { - log.Error("[rpc] failed to register subscription", "error", err) } continue } + // handle non-subscribe calls client := &http.Client{} buf := &bytes.Buffer{} _, err = buf.Write(mbytes) if err != nil { - log.Error("[rpc] failed to write message to buffer", "error", err) + logger.Warn("failed to write message to buffer", "error", err) return } req, err := http.NewRequest("POST", rpcHost, buf) if err != nil { - log.Error("[rpc] failed request to rpc service", "error", err) + logger.Warn("failed request to rpc service", "error", err) return } @@ -160,76 +195,169 @@ func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { res, err := client.Do(req) if err != nil { - log.Error("[rpc] websocket error calling rpc", "error", err) + logger.Warn("websocket error calling rpc", "error", err) return } body, err := ioutil.ReadAll(res.Body) if err != nil { - log.Error("[rpc] error reading response body", "error", err) + logger.Warn("error reading response body", "error", err) return } err = res.Body.Close() if err != nil { - log.Error("[rpc] error closing response body", "error", err) + logger.Warn("error closing response body", "error", err) return } var wsSend interface{} err = json.Unmarshal(body, &wsSend) if err != nil { - log.Error("[rpc] error unmarshal rpc response", "error", err) + logger.Warn("error unmarshal rpc response", "error", err) return } - err = ws.WriteJSON(wsSend) + err = c.safeSend(wsSend) if err != nil { - log.Error("[rpc] error writing json response", "error", err) + logger.Warn("error writing json response", "error", err) return } } +} +func (c *WSConn) startListener(lid int) { + go c.subscriptions[lid].Listen() +} + +// Listener interface for functions that define Listener related functions +type Listener interface { + Listen() +} + +// StorageChangeListener for listening to state change channels +type StorageChangeListener struct { + channel chan *state.KeyValue + filter map[string]bool + wsconn *WSConn + chanID byte + subID int +} + +func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (int, error) { + scl := &StorageChangeListener{ + channel: make(chan *state.KeyValue), + filter: make(map[string]bool), + wsconn: c, + } + pA := params.([]interface{}) + for _, param := range pA { + scl.filter[param.(string)] = true + } + if c.storageAPI == nil { + err := c.safeSendError(reqID, nil, "error StorageAPI not set") + if err != nil { + logger.Warn("error sending error message", "error", err) + } + return 0, fmt.Errorf("error StorageAPI not set") + } + chanID, err := c.storageAPI.RegisterStorageChangeChannel(scl.channel) + if err != nil { + return 0, err + } + scl.chanID = chanID + + c.qtyListeners++ + scl.subID = c.qtyListeners + c.subscriptions[scl.subID] = scl + c.storageSubChannels[scl.subID] = chanID + initRes := newSubscriptionResponseJSON(scl.subID, reqID) + err = c.safeSend(initRes) + if err != nil { + return 0, err + } + return scl.subID, nil } -func (h *HTTPServer) registerSubscription(conn *websocket.Conn, reqID float64, subscriptionType int) (uint32, error) { - wssub := h.serverConfig.WSSubscriptions - sub := uint32(len(wssub)) + 1 - wss := &WebSocketSubscription{ - WSConnection: conn, - SubscriptionType: subscriptionType, +// Listen implementation of Listen interface to listen for channel changes +func (l *StorageChangeListener) Listen() { + for change := range l.channel { + if change == nil { + continue + } + + //check if change key is in subscription filter + cKey := common.BytesToHex(change.Key) + if len(l.filter) > 0 && !l.filter[cKey] { + continue + } + + changeM := make(map[string]interface{}) + changeM["result"] = []string{cKey, common.BytesToHex(change.Value)} + res := newSubcriptionBaseResponseJSON(l.subID) + res.Method = "state_storage" + res.Params = changeM + err := l.wsconn.safeSend(res) + if err != nil { + logger.Error("error sending websocket message", "error", err) + } + } - wssub[sub] = wss - h.serverConfig.WSSubscriptions = wssub - initRes := newSubscriptionResponseJSON() - initRes.Result = sub - initRes.ID = reqID +} - return sub, conn.WriteJSON(initRes) +// BlockListener to handle listening for blocks channel +type BlockListener struct { + channel chan *types.Block + wsconn *WSConn + chanID byte + subID int } -func (h *HTTPServer) blockReceivedListener() { - if h.serverConfig.BlockAPI == nil { - return +func (c *WSConn) initBlockListener(reqID float64) (int, error) { + bl := &BlockListener{ + channel: make(chan *types.Block), + wsconn: c, } - for block := range h.blockChan { - if block != nil { - for i, sub := range h.serverConfig.WSSubscriptions { - if sub.SubscriptionType == SUB_NEW_HEAD { - head := modules.HeaderToJSON(*block.Header) - headM := make(map[string]interface{}) - headM["result"] = head - res := newSubcriptionBaseResponseJSON(i) - res.Method = "chain_newHead" - res.Params = headM - if sub.WSConnection != nil { - err := sub.WSConnection.WriteJSON(res) - if err != nil { - log.Error("[rpc] error writing response", "error", err) - } - } - } - } + if c.blockAPI == nil { + err := c.safeSendError(reqID, nil, "error BlockAPI not set") + if err != nil { + logger.Warn("error sending error message", "error", err) + } + return 0, fmt.Errorf("error BlockAPI not set") + } + chanID, err := c.blockAPI.RegisterImportedChannel(bl.channel) + if err != nil { + return 0, err + } + bl.chanID = chanID + c.qtyListeners++ + bl.subID = c.qtyListeners + c.subscriptions[bl.subID] = bl + c.blockSubChannels[bl.subID] = chanID + initRes := newSubscriptionResponseJSON(bl.subID, reqID) + err = c.safeSend(initRes) + if err != nil { + return 0, err + } + return bl.subID, nil +} + +// Listen implementation of Listen interface to listen for channel changes +func (l *BlockListener) Listen() { + for block := range l.channel { + if block == nil { + continue + } + head := modules.HeaderToJSON(*block.Header) + headM := make(map[string]interface{}) + headM["result"] = head + res := newSubcriptionBaseResponseJSON(l.subID) + res.Method = "chain_newHead" + res.Params = headM + err := l.wsconn.safeSend(res) + if err != nil { + logger.Error("error sending websocket message", "error", err) } + } } diff --git a/dot/rpc/websocket_test.go b/dot/rpc/websocket_test.go index b0ae6e17b4..3624f3333f 100644 --- a/dot/rpc/websocket_test.go +++ b/dot/rpc/websocket_test.go @@ -3,13 +3,16 @@ package rpc import ( "flag" "log" + "math/big" "net/url" "testing" "time" "github.com/ChainSafe/gossamer/dot/core" + "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/system" "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" "github.com/gorilla/websocket" "github.com/stretchr/testify/require" ) @@ -21,24 +24,29 @@ var testCalls = []struct { }{ {[]byte(`{"jsonrpc":"2.0","method":"system_name","params":[],"id":1}`), []byte(`{"id":1,"jsonrpc":"2.0","result":"gossamer"}` + "\n")}, // working request {[]byte(`{"jsonrpc":"2.0","method":"unknown","params":[],"id":1}`), []byte(`{"error":{"code":-32000,"data":null,"message":"rpc error method unknown not found"},"id":1,"jsonrpc":"2.0"}` + "\n")}, // unknown method - {[]byte{}, []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":null}` + "\n")}, // empty request - {[]byte(`{"jsonrpc":"2.0","method":"chain_subscribeNewHeads","params":[],"id":1}`), []byte(`{"jsonrpc":"2.0","result":1,"id":1}` + "\n")}, + {[]byte{}, []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n")}, // empty request + {[]byte(`{"jsonrpc":"2.0","method":"chain_subscribeNewHeads","params":[],"id":3}`), []byte(`{"jsonrpc":"2.0","result":1,"id":3}` + "\n")}, + {[]byte(`{"jsonrpc":"2.0","method":"state_subscribeStorage","params":[],"id":4}`), []byte(`{"jsonrpc":"2.0","result":2,"id":4}` + "\n")}, } -func TestNewWebSocketServer(t *testing.T) { +func TestHTTPServer_ServeHTTP(t *testing.T) { coreAPI := core.NewTestService(t, nil) si := &types.SystemInfo{ SystemName: "gossamer", } sysAPI := system.NewService(si) + bAPI := new(MockBlockAPI) + sAPI := new(MockStorageAPI) cfg := &HTTPServerConfig{ - Modules: []string{"system", "chain"}, - RPCPort: 8545, - WSPort: 8546, - WSEnabled: true, - RPCAPI: NewService(), - CoreAPI: coreAPI, - SystemAPI: sysAPI, + Modules: []string{"system", "chain"}, + RPCPort: 8545, + WSPort: 8546, + WSEnabled: true, + RPCAPI: NewService(), + CoreAPI: coreAPI, + SystemAPI: sysAPI, + BlockAPI: bAPI, + StorageAPI: sAPI, } s := NewHTTPServer(cfg) @@ -65,3 +73,42 @@ func TestNewWebSocketServer(t *testing.T) { require.Equal(t, item.expected, message) } } + +type MockBlockAPI struct { +} + +func (m *MockBlockAPI) GetHeader(hash common.Hash) (*types.Header, error) { + return nil, nil +} +func (m *MockBlockAPI) HighestBlockHash() common.Hash { + return common.Hash{} +} +func (m *MockBlockAPI) GetBlockByHash(hash common.Hash) (*types.Block, error) { + return nil, nil +} +func (m *MockBlockAPI) GetBlockHash(blockNumber *big.Int) (*common.Hash, error) { + return nil, nil +} +func (m *MockBlockAPI) GetFinalizedHash(uint64) (common.Hash, error) { + return common.Hash{}, nil +} +func (m *MockBlockAPI) RegisterImportedChannel(ch chan<- *types.Block) (byte, error) { + return 0, nil +} +func (m *MockBlockAPI) UnregisterImportedChannel(id byte) { +} + +type MockStorageAPI struct{} + +func (m *MockStorageAPI) GetStorage(key []byte) ([]byte, error) { + return nil, nil +} +func (m *MockStorageAPI) Entries() map[string][]byte { + return nil +} +func (m *MockStorageAPI) RegisterStorageChangeChannel(ch chan<- *state.KeyValue) (byte, error) { + return 0, nil +} +func (m *MockStorageAPI) UnregisterStorageChangeChannel(id byte) { + +} diff --git a/dot/services_test.go b/dot/services_test.go index 9d77413c22..4182cd721c 100644 --- a/dot/services_test.go +++ b/dot/services_test.go @@ -17,12 +17,16 @@ package dot import ( + "flag" "math/big" + "net/url" "testing" + "time" "github.com/ChainSafe/gossamer/dot/network" "github.com/ChainSafe/gossamer/lib/keystore" "github.com/ChainSafe/gossamer/lib/utils" + "github.com/gorilla/websocket" "github.com/stretchr/testify/require" ) @@ -229,3 +233,75 @@ func TestCreateGrandpaService(t *testing.T) { require.NoError(t, err) require.NotNil(t, gs) } + +var addr = flag.String("addr", "localhost:8546", "http service address") +var testCalls = []struct { + call []byte + expected []byte +}{ + {[]byte(`{"jsonrpc":"2.0","method":"system_name","params":[],"id":1}`), []byte(`{"id":1,"jsonrpc":"2.0","result":"gossamer"}` + "\n")}, // working request + {[]byte(`{"jsonrpc":"2.0","method":"unknown","params":[],"id":2}`), []byte(`{"error":{"code":-32000,"data":null,"message":"rpc error method unknown not found"},"id":2,"jsonrpc":"2.0"}` + "\n")}, // unknown method + {[]byte{}, []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n")}, // empty request + {[]byte(`{"jsonrpc":"2.0","method":"chain_subscribeNewHeads","params":[],"id":3}`), []byte(`{"jsonrpc":"2.0","result":1,"id":3}` + "\n")}, + {[]byte(`{"jsonrpc":"2.0","method":"state_subscribeStorage","params":[],"id":4}`), []byte(`{"jsonrpc":"2.0","result":2,"id":4}` + "\n")}, +} + +func TestNewWebSocketServer(t *testing.T) { + cfg := NewTestConfig(t) + require.NotNil(t, cfg) + + genFile := NewTestGenesisFile(t, cfg) + require.NotNil(t, genFile) + + defer utils.RemoveTestDir(t) + + cfg.Core.Authority = false + cfg.Core.BabeAuthority = false + cfg.Core.GrandpaAuthority = false + cfg.Init.Genesis = genFile.Name() + cfg.RPC.WSEnabled = true + cfg.System.SystemName = "gossamer" + + err := InitNode(cfg) + require.Nil(t, err) + + stateSrvc, err := createStateService(cfg) + require.Nil(t, err) + + coreMsgs := make(chan network.Message) + networkMsgs := make(chan network.Message) + + ks := keystore.NewKeystore() + rt, err := createRuntime(cfg, stateSrvc, ks) + require.NoError(t, err) + + coreSrvc, err := createCoreService(cfg, nil, nil, rt, ks, stateSrvc, coreMsgs, networkMsgs, make(chan *big.Int)) + require.Nil(t, err) + + networkSrvc := &network.Service{} + + sysSrvc := createSystemService(&cfg.System) + + rpcSrvc, err := createRPCService(cfg, stateSrvc, coreSrvc, networkSrvc, nil, rt, sysSrvc) + require.Nil(t, err) + + err = rpcSrvc.Start() + require.Nil(t, err) + + time.Sleep(time.Second) // give server a second to start + + u := url.URL{Scheme: "ws", Host: *addr, Path: "/"} + + c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + require.NoError(t, err) + defer c.Close() + + for _, item := range testCalls { + err = c.WriteMessage(websocket.TextMessage, item.call) + require.Nil(t, err) + + _, message, err := c.ReadMessage() + require.Nil(t, err) + require.Equal(t, item.expected, message) + } +} diff --git a/dot/state/storage.go b/dot/state/storage.go index 10a123c1be..6cb121fd9c 100644 --- a/dot/state/storage.go +++ b/dot/state/storage.go @@ -52,6 +52,10 @@ type StorageState struct { trie *trie.Trie db *StorageDB lock sync.RWMutex + + // change notifiers + changed map[byte]chan<- *KeyValue + changedLock sync.RWMutex } // NewStorageDB instantiates badgerDB instance for storing trie structure @@ -72,8 +76,9 @@ func NewStorageState(db chaindb.Database, t *trie.Trie) (*StorageState, error) { } return &StorageState{ - trie: t, - db: NewStorageDB(db), + trie: t, + db: NewStorageDB(db), + changed: make(map[byte]chan<- *KeyValue), }, nil } @@ -120,7 +125,16 @@ func (s *StorageState) EnumeratedTrieRoot(values [][]byte) { func (s *StorageState) SetStorage(key []byte, value []byte) error { s.lock.Lock() defer s.lock.Unlock() - return s.trie.Put(key, value) + kv := &KeyValue{ + Key: key, + Value: value, + } + err := s.trie.Put(key, value) + if err != nil { + return err + } + s.notifyChanged(kv) + return nil } // ClearPrefix not implemented @@ -133,7 +147,16 @@ func (s *StorageState) ClearPrefix(prefix []byte) { func (s *StorageState) ClearStorage(key []byte) error { s.lock.Lock() defer s.lock.Unlock() - return s.trie.Delete(key) + kv := &KeyValue{ + Key: key, + Value: nil, + } + err := s.trie.Delete(key) + if err != nil { + return err + } + s.notifyChanged(kv) + return nil } // Entries returns Entries from the trie diff --git a/dot/state/storage_notify.go b/dot/state/storage_notify.go new file mode 100644 index 0000000000..cd49a9e911 --- /dev/null +++ b/dot/state/storage_notify.go @@ -0,0 +1,76 @@ +// Copyright 2020 ChainSafe Systems (ON) Corp. +// This file is part of gossamer. +// +// The gossamer library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The gossamer library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the gossamer library. If not, see . +package state + +import ( + "errors" +) + +// KeyValue struct to hold key value pairs +type KeyValue struct { + Key []byte + Value []byte +} + +// RegisterStorageChangeChannel function to register storage change channels +func (s *StorageState) RegisterStorageChangeChannel(ch chan<- *KeyValue) (byte, error) { + s.changedLock.RLock() + + if len(s.changed) == 256 { + return 0, errors.New("channel limit reached") + } + + var id byte + for { + id = generateID() + if s.changed[id] == nil { + break + } + } + + s.changedLock.RUnlock() + + s.changedLock.Lock() + s.changed[id] = ch + s.changedLock.Unlock() + return id, nil +} + +// UnregisterStorageChangeChannel removes the storage change notification channel with the given ID. +// A channel must be unregistered before closing it. +func (s *StorageState) UnregisterStorageChangeChannel(id byte) { + s.changedLock.Lock() + defer s.changedLock.Unlock() + + delete(s.changed, id) +} + +func (s *StorageState) notifyChanged(change *KeyValue) { + s.changedLock.RLock() + defer s.changedLock.RUnlock() + + if len(s.changed) == 0 { + return + } + + logger.Trace("notifying changed storage chans...", "chans", s.changed) + + for _, ch := range s.changed { + go func(ch chan<- *KeyValue) { + ch <- change + }(ch) + } +} diff --git a/dot/state/storage_notify_test.go b/dot/state/storage_notify_test.go new file mode 100644 index 0000000000..02cd9a0f3d --- /dev/null +++ b/dot/state/storage_notify_test.go @@ -0,0 +1,88 @@ +// Copyright 2020 ChainSafe Systems (ON) Corp. +// This file is part of gossamer. +// +// The gossamer library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The gossamer library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the gossamer library. If not, see . +package state + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestStorageState_RegisterStorageChangeChannel(t *testing.T) { + ss := newTestStorageState(t) + + ch := make(chan *KeyValue, 3) + id, err := ss.RegisterStorageChangeChannel(ch) + require.NoError(t, err) + + defer ss.UnregisterStorageChangeChannel(id) + + // three storage change events + ss.SetStorage([]byte("mackcom"), []byte("wuz here")) + ss.SetStorage([]byte("key1"), []byte("value1")) + ss.SetStorage([]byte("key1"), []byte("value2")) + + for i := 0; i < 3; i++ { + select { + case <-ch: + case <-time.After(testMessageTimeout): + t.Fatal("did not receive storage change message") + } + } +} + +func TestStorageState_RegisterStorageChangeChannel_Multi(t *testing.T) { + ss := newTestStorageState(t) + + num := 5 + chs := make([]chan *KeyValue, num) + ids := make([]byte, num) + + var err error + for i := 0; i < num; i++ { + chs[i] = make(chan *KeyValue) + ids[i], err = ss.RegisterStorageChangeChannel(chs[i]) + require.NoError(t, err) + } + + key1 := []byte("key1") + ss.SetStorage(key1, []byte("value1")) + + var wg sync.WaitGroup + wg.Add(num) + + for i, ch := range chs { + + go func(i int, ch chan *KeyValue) { + select { + case c := <-ch: + require.Equal(t, key1, c.Key) + wg.Done() + case <-time.After(testMessageTimeout): + t.Error("did not receive storage change: ch=", i) + } + }(i, ch) + + } + + wg.Wait() + + for _, id := range ids { + ss.UnregisterStorageChangeChannel(id) + } +}