diff --git a/dot/rpc/http.go b/dot/rpc/http.go
index 82490dd561..2efce0136e 100644
--- a/dot/rpc/http.go
+++ b/dot/rpc/http.go
@@ -20,14 +20,13 @@ import (
"fmt"
"net/http"
"os"
+ "sync"
"github.com/ChainSafe/gossamer/dot/rpc/modules"
- "github.com/ChainSafe/gossamer/dot/types"
+ log "github.com/ChainSafe/log15"
"github.com/gorilla/mux"
"github.com/gorilla/rpc/v2"
"github.com/gorilla/websocket"
-
- log "github.com/ChainSafe/log15"
)
// HTTPServer gateway for RPC server
@@ -35,8 +34,7 @@ type HTTPServer struct {
logger log.Logger
rpcServer *rpc.Server // Actual RPC call handler
serverConfig *HTTPServerConfig
- blockChan chan *types.Block
- chanID byte // channel ID
+ wsConns []*WSConn
}
// HTTPServerConfig configures the HTTPServer
@@ -56,18 +54,25 @@ type HTTPServerConfig struct {
WSEnabled bool
WSPort uint32
Modules []string
- WSSubscriptions map[uint32]*WebSocketSubscription
}
-// WebSocketSubscription holds subscription details
-type WebSocketSubscription struct {
- WSConnection *websocket.Conn
- SubscriptionType int
+// WSConn struct to hold WebSocket Connection references
+type WSConn struct {
+ wsconn *websocket.Conn
+ mu sync.Mutex
+ blockSubChannels map[int]byte
+ storageSubChannels map[int]byte
+ qtyListeners int
+ subscriptions map[int]Listener
+ storageAPI modules.StorageAPI
+ blockAPI modules.BlockAPI
}
+var logger log.Logger
+
// NewHTTPServer creates a new http server and registers an associated rpc server
func NewHTTPServer(cfg *HTTPServerConfig) *HTTPServer {
- logger := log.New("pkg", "rpc")
+ logger = log.New("pkg", "rpc")
h := log.StreamHandler(os.Stdout, log.TerminalFormat())
logger.SetHandler(log.LvlFilterHandler(cfg.LogLvl, h))
@@ -77,10 +82,6 @@ func NewHTTPServer(cfg *HTTPServerConfig) *HTTPServer {
serverConfig: cfg,
}
- if cfg.WSSubscriptions == nil {
- cfg.WSSubscriptions = make(map[uint32]*WebSocketSubscription)
- }
-
server.RegisterModules(cfg.Modules)
return server
}
@@ -151,25 +152,30 @@ func (h *HTTPServer) Start() error {
}
}()
- // init and start block received listener routine
- if h.serverConfig.BlockAPI != nil {
- var err error
- h.blockChan = make(chan *types.Block)
- h.chanID, err = h.serverConfig.BlockAPI.RegisterImportedChannel(h.blockChan)
- if err != nil {
- return err
- }
- go h.blockReceivedListener()
- }
-
return nil
}
// Stop stops the server
func (h *HTTPServer) Stop() error {
if h.serverConfig.WSEnabled {
- h.serverConfig.BlockAPI.UnregisterImportedChannel(h.chanID)
- close(h.blockChan)
+ // close all channels and websocket connections
+ for _, conn := range h.wsConns {
+ for _, sub := range conn.subscriptions {
+ switch v := sub.(type) {
+ case *StorageChangeListener:
+ h.serverConfig.StorageAPI.UnregisterStorageChangeChannel(v.chanID)
+ close(v.channel)
+ case *BlockListener:
+ h.serverConfig.BlockAPI.UnregisterImportedChannel(v.chanID)
+ close(v.channel)
+ }
+ }
+
+ err := conn.wsconn.Close()
+ if err != nil {
+ h.logger.Error("error closing websocket connection", "error", err)
+ }
+ }
}
return nil
}
diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go
index a59efb70d9..0e8b0aa6e9 100644
--- a/dot/rpc/modules/api.go
+++ b/dot/rpc/modules/api.go
@@ -3,6 +3,7 @@ package modules
import (
"math/big"
+ "github.com/ChainSafe/gossamer/dot/state"
"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/crypto"
@@ -14,6 +15,8 @@ import (
type StorageAPI interface {
GetStorage(key []byte) ([]byte, error)
Entries() map[string][]byte
+ RegisterStorageChangeChannel(ch chan<- *state.KeyValue) (byte, error)
+ UnregisterStorageChangeChannel(id byte)
}
// BlockAPI is the interface for the block state
diff --git a/dot/rpc/modules/state.go b/dot/rpc/modules/state.go
index 04e7659171..3d322f107b 100644
--- a/dot/rpc/modules/state.go
+++ b/dot/rpc/modules/state.go
@@ -285,9 +285,11 @@ func (sm *StateModule) SubscribeRuntimeVersion(r *http.Request, req *StateStorag
return sm.GetRuntimeVersion(r, nil, res)
}
-// SubscribeStorage isn't implemented properly yet.
-func (sm *StateModule) SubscribeStorage(r *http.Request, req *StateStorageQueryRangeRequest, res *StorageChangeSetResponse) {
- // TODO implement change storage trie so that block hash parameter works (See issue #834)
+// SubscribeStorage Storage subscription. If storage keys are specified, it creates a message for each block which
+// changes the specified storage keys. If none are specified, then it creates a message for every block.
+// This endpoint communicates over the Websocket protocol, but this func should remain here so it's added to rpc_methods list
+func (sm *StateModule) SubscribeStorage(r *http.Request, req *StateStorageQueryRangeRequest, res *StorageChangeSetResponse) error {
+ return nil
}
func convertAPIs(in []*runtime.API_Item) []interface{} {
diff --git a/dot/rpc/websocket.go b/dot/rpc/websocket.go
index 2b2cefa17c..ff81da2a0e 100644
--- a/dot/rpc/websocket.go
+++ b/dot/rpc/websocket.go
@@ -25,44 +25,39 @@ import (
"strings"
"github.com/ChainSafe/gossamer/dot/rpc/modules"
-
- "github.com/ethereum/go-ethereum/log"
+ "github.com/ChainSafe/gossamer/dot/state"
+ "github.com/ChainSafe/gossamer/dot/types"
+ "github.com/ChainSafe/gossamer/lib/common"
"github.com/gorilla/websocket"
)
-// consts to represent subscription type
-const (
- SUB_NEW_HEAD = iota
- SUB_FINALIZED_HEAD
- SUB_STORAGE
-)
-
// SubscriptionBaseResponseJSON for base json response
type SubscriptionBaseResponseJSON struct {
Jsonrpc string `json:"jsonrpc"`
Method string `json:"method"`
Params interface{} `json:"params"`
- Subscription uint32 `json:"subscription"`
+ Subscription int `json:"subscription"`
}
-func newSubcriptionBaseResponseJSON(sub uint32) SubscriptionBaseResponseJSON {
+func newSubcriptionBaseResponseJSON(subID int) SubscriptionBaseResponseJSON {
return SubscriptionBaseResponseJSON{
Jsonrpc: "2.0",
- Subscription: sub,
+ Subscription: subID,
}
}
// SubscriptionResponseJSON for json subscription responses
type SubscriptionResponseJSON struct {
Jsonrpc string `json:"jsonrpc"`
- Result uint32 `json:"result"`
+ Result int `json:"result"`
ID float64 `json:"id"`
}
-func newSubscriptionResponseJSON() SubscriptionResponseJSON {
+func newSubscriptionResponseJSON(subID int, reqID float64) SubscriptionResponseJSON {
return SubscriptionResponseJSON{
Jsonrpc: "2.0",
- Result: 0,
+ Result: subID,
+ ID: reqID,
}
}
@@ -70,7 +65,7 @@ func newSubscriptionResponseJSON() SubscriptionResponseJSON {
type ErrorResponseJSON struct {
Jsonrpc string `json:"jsonrpc"`
Error *ErrorMessageJSON `json:"error"`
- ID *big.Int `json:"id"`
+ ID float64 `json:"id"`
}
// ErrorMessageJSON json for error messages
@@ -79,80 +74,120 @@ type ErrorMessageJSON struct {
Message string `json:"message"`
}
+var rpcHost string
+
// ServeHTTP implemented to handle WebSocket connections
func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var upg = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
- return true
+ return true // todo determine how this should check orgigin
},
}
+
ws, err := upg.Upgrade(w, r, nil)
if err != nil {
- log.Error("[rpc] websocket upgrade failed", "error", err)
+ h.logger.Error("websocket upgrade failed", "error", err)
return
}
+ // create wsConn
+ wsc := NewWSConn(ws, h.serverConfig)
+ h.wsConns = append(h.wsConns, wsc)
- rpcHost := fmt.Sprintf("http://%s:%d/", h.serverConfig.Host, h.serverConfig.RPCPort)
+ go wsc.handleComm()
+}
+
+// NewWSConn to create new WebSocket Connection struct
+func NewWSConn(conn *websocket.Conn, cfg *HTTPServerConfig) *WSConn {
+ rpcHost = fmt.Sprintf("http://%s:%d/", cfg.Host, cfg.RPCPort)
+ c := &WSConn{
+ wsconn: conn,
+ subscriptions: make(map[int]Listener),
+ blockSubChannels: make(map[int]byte),
+ storageSubChannels: make(map[int]byte),
+ storageAPI: cfg.StorageAPI,
+ blockAPI: cfg.BlockAPI,
+ }
+ return c
+}
+
+func (c *WSConn) safeSend(msg interface{}) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ return c.wsconn.WriteJSON(msg)
+}
+func (c *WSConn) safeSendError(reqID float64, errorCode *big.Int, message string) error {
+ res := &ErrorResponseJSON{
+ Jsonrpc: "2.0",
+ Error: &ErrorMessageJSON{
+ Code: errorCode,
+ Message: message,
+ },
+ ID: reqID,
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ return c.wsconn.WriteJSON(res)
+}
+
+func (c *WSConn) handleComm() {
for {
- _, mbytes, err := ws.ReadMessage()
+ _, mbytes, err := c.wsconn.ReadMessage()
if err != nil {
- log.Error("[rpc] websocket failed to read message", "error", err)
+ logger.Warn("websocket failed to read message", "error", err)
return
}
- log.Trace("[rpc] websocket received", "message", fmt.Sprintf("%s", mbytes))
+ logger.Debug("websocket received", "message", fmt.Sprintf("%s", mbytes))
// determine if request is for subscribe method type
var msg map[string]interface{}
err = json.Unmarshal(mbytes, &msg)
if err != nil {
- log.Error("[rpc] websocket failed to unmarshal request message", "error", err)
- res := &ErrorResponseJSON{
- Jsonrpc: "2.0",
- Error: &ErrorMessageJSON{
- Code: big.NewInt(-32600),
- Message: "Invalid request",
- },
- ID: nil,
- }
- err = ws.WriteJSON(res)
+ logger.Warn("websocket failed to unmarshal request message", "error", err)
+ err = c.safeSendError(0, big.NewInt(-32600), "Invalid request")
if err != nil {
- log.Error("[rpc] websocket failed write message", "error", err)
+ logger.Warn("websocket failed write message", "error", err)
}
continue
}
method := msg["method"]
// if method contains subscribe, then register subscription
if strings.Contains(fmt.Sprintf("%s", method), "subscribe") {
- mid := msg["id"].(float64)
- var subType int
+ reqid := msg["id"].(float64)
+ params := msg["params"]
switch method {
case "chain_subscribeNewHeads", "chain_subscribeNewHead":
- subType = SUB_NEW_HEAD
- case "chain_subscribeStorage":
- subType = SUB_STORAGE
+ bl, err1 := c.initBlockListener(reqid)
+ if err1 != nil {
+ logger.Warn("failed to create block listener", "error", err)
+ continue
+ }
+ c.startListener(bl)
+ case "state_subscribeStorage":
+ scl, err2 := c.initStorageChangeListener(reqid, params)
+ if err2 != nil {
+ logger.Warn("failed to create state change listener", "error", err)
+ continue
+ }
+ c.startListener(scl)
case "chain_subscribeFinalizedHeads":
- subType = SUB_FINALIZED_HEAD
- }
-
- var e1 error
- _, e1 = h.registerSubscription(ws, mid, subType)
- if e1 != nil {
- log.Error("[rpc] failed to register subscription", "error", err)
}
continue
}
+ // handle non-subscribe calls
client := &http.Client{}
buf := &bytes.Buffer{}
_, err = buf.Write(mbytes)
if err != nil {
- log.Error("[rpc] failed to write message to buffer", "error", err)
+ logger.Warn("failed to write message to buffer", "error", err)
return
}
req, err := http.NewRequest("POST", rpcHost, buf)
if err != nil {
- log.Error("[rpc] failed request to rpc service", "error", err)
+ logger.Warn("failed request to rpc service", "error", err)
return
}
@@ -160,76 +195,169 @@ func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
res, err := client.Do(req)
if err != nil {
- log.Error("[rpc] websocket error calling rpc", "error", err)
+ logger.Warn("websocket error calling rpc", "error", err)
return
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
- log.Error("[rpc] error reading response body", "error", err)
+ logger.Warn("error reading response body", "error", err)
return
}
err = res.Body.Close()
if err != nil {
- log.Error("[rpc] error closing response body", "error", err)
+ logger.Warn("error closing response body", "error", err)
return
}
var wsSend interface{}
err = json.Unmarshal(body, &wsSend)
if err != nil {
- log.Error("[rpc] error unmarshal rpc response", "error", err)
+ logger.Warn("error unmarshal rpc response", "error", err)
return
}
- err = ws.WriteJSON(wsSend)
+ err = c.safeSend(wsSend)
if err != nil {
- log.Error("[rpc] error writing json response", "error", err)
+ logger.Warn("error writing json response", "error", err)
return
}
}
+}
+func (c *WSConn) startListener(lid int) {
+ go c.subscriptions[lid].Listen()
+}
+
+// Listener interface for functions that define Listener related functions
+type Listener interface {
+ Listen()
+}
+
+// StorageChangeListener for listening to state change channels
+type StorageChangeListener struct {
+ channel chan *state.KeyValue
+ filter map[string]bool
+ wsconn *WSConn
+ chanID byte
+ subID int
+}
+
+func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (int, error) {
+ scl := &StorageChangeListener{
+ channel: make(chan *state.KeyValue),
+ filter: make(map[string]bool),
+ wsconn: c,
+ }
+ pA := params.([]interface{})
+ for _, param := range pA {
+ scl.filter[param.(string)] = true
+ }
+ if c.storageAPI == nil {
+ err := c.safeSendError(reqID, nil, "error StorageAPI not set")
+ if err != nil {
+ logger.Warn("error sending error message", "error", err)
+ }
+ return 0, fmt.Errorf("error StorageAPI not set")
+ }
+ chanID, err := c.storageAPI.RegisterStorageChangeChannel(scl.channel)
+ if err != nil {
+ return 0, err
+ }
+ scl.chanID = chanID
+
+ c.qtyListeners++
+ scl.subID = c.qtyListeners
+ c.subscriptions[scl.subID] = scl
+ c.storageSubChannels[scl.subID] = chanID
+ initRes := newSubscriptionResponseJSON(scl.subID, reqID)
+ err = c.safeSend(initRes)
+ if err != nil {
+ return 0, err
+ }
+ return scl.subID, nil
}
-func (h *HTTPServer) registerSubscription(conn *websocket.Conn, reqID float64, subscriptionType int) (uint32, error) {
- wssub := h.serverConfig.WSSubscriptions
- sub := uint32(len(wssub)) + 1
- wss := &WebSocketSubscription{
- WSConnection: conn,
- SubscriptionType: subscriptionType,
+// Listen implementation of Listen interface to listen for channel changes
+func (l *StorageChangeListener) Listen() {
+ for change := range l.channel {
+ if change == nil {
+ continue
+ }
+
+ //check if change key is in subscription filter
+ cKey := common.BytesToHex(change.Key)
+ if len(l.filter) > 0 && !l.filter[cKey] {
+ continue
+ }
+
+ changeM := make(map[string]interface{})
+ changeM["result"] = []string{cKey, common.BytesToHex(change.Value)}
+ res := newSubcriptionBaseResponseJSON(l.subID)
+ res.Method = "state_storage"
+ res.Params = changeM
+ err := l.wsconn.safeSend(res)
+ if err != nil {
+ logger.Error("error sending websocket message", "error", err)
+ }
+
}
- wssub[sub] = wss
- h.serverConfig.WSSubscriptions = wssub
- initRes := newSubscriptionResponseJSON()
- initRes.Result = sub
- initRes.ID = reqID
+}
- return sub, conn.WriteJSON(initRes)
+// BlockListener to handle listening for blocks channel
+type BlockListener struct {
+ channel chan *types.Block
+ wsconn *WSConn
+ chanID byte
+ subID int
}
-func (h *HTTPServer) blockReceivedListener() {
- if h.serverConfig.BlockAPI == nil {
- return
+func (c *WSConn) initBlockListener(reqID float64) (int, error) {
+ bl := &BlockListener{
+ channel: make(chan *types.Block),
+ wsconn: c,
}
- for block := range h.blockChan {
- if block != nil {
- for i, sub := range h.serverConfig.WSSubscriptions {
- if sub.SubscriptionType == SUB_NEW_HEAD {
- head := modules.HeaderToJSON(*block.Header)
- headM := make(map[string]interface{})
- headM["result"] = head
- res := newSubcriptionBaseResponseJSON(i)
- res.Method = "chain_newHead"
- res.Params = headM
- if sub.WSConnection != nil {
- err := sub.WSConnection.WriteJSON(res)
- if err != nil {
- log.Error("[rpc] error writing response", "error", err)
- }
- }
- }
- }
+ if c.blockAPI == nil {
+ err := c.safeSendError(reqID, nil, "error BlockAPI not set")
+ if err != nil {
+ logger.Warn("error sending error message", "error", err)
+ }
+ return 0, fmt.Errorf("error BlockAPI not set")
+ }
+ chanID, err := c.blockAPI.RegisterImportedChannel(bl.channel)
+ if err != nil {
+ return 0, err
+ }
+ bl.chanID = chanID
+ c.qtyListeners++
+ bl.subID = c.qtyListeners
+ c.subscriptions[bl.subID] = bl
+ c.blockSubChannels[bl.subID] = chanID
+ initRes := newSubscriptionResponseJSON(bl.subID, reqID)
+ err = c.safeSend(initRes)
+ if err != nil {
+ return 0, err
+ }
+ return bl.subID, nil
+}
+
+// Listen implementation of Listen interface to listen for channel changes
+func (l *BlockListener) Listen() {
+ for block := range l.channel {
+ if block == nil {
+ continue
+ }
+ head := modules.HeaderToJSON(*block.Header)
+ headM := make(map[string]interface{})
+ headM["result"] = head
+ res := newSubcriptionBaseResponseJSON(l.subID)
+ res.Method = "chain_newHead"
+ res.Params = headM
+ err := l.wsconn.safeSend(res)
+ if err != nil {
+ logger.Error("error sending websocket message", "error", err)
}
+
}
}
diff --git a/dot/rpc/websocket_test.go b/dot/rpc/websocket_test.go
index b0ae6e17b4..3624f3333f 100644
--- a/dot/rpc/websocket_test.go
+++ b/dot/rpc/websocket_test.go
@@ -3,13 +3,16 @@ package rpc
import (
"flag"
"log"
+ "math/big"
"net/url"
"testing"
"time"
"github.com/ChainSafe/gossamer/dot/core"
+ "github.com/ChainSafe/gossamer/dot/state"
"github.com/ChainSafe/gossamer/dot/system"
"github.com/ChainSafe/gossamer/dot/types"
+ "github.com/ChainSafe/gossamer/lib/common"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
)
@@ -21,24 +24,29 @@ var testCalls = []struct {
}{
{[]byte(`{"jsonrpc":"2.0","method":"system_name","params":[],"id":1}`), []byte(`{"id":1,"jsonrpc":"2.0","result":"gossamer"}` + "\n")}, // working request
{[]byte(`{"jsonrpc":"2.0","method":"unknown","params":[],"id":1}`), []byte(`{"error":{"code":-32000,"data":null,"message":"rpc error method unknown not found"},"id":1,"jsonrpc":"2.0"}` + "\n")}, // unknown method
- {[]byte{}, []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":null}` + "\n")}, // empty request
- {[]byte(`{"jsonrpc":"2.0","method":"chain_subscribeNewHeads","params":[],"id":1}`), []byte(`{"jsonrpc":"2.0","result":1,"id":1}` + "\n")},
+ {[]byte{}, []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n")}, // empty request
+ {[]byte(`{"jsonrpc":"2.0","method":"chain_subscribeNewHeads","params":[],"id":3}`), []byte(`{"jsonrpc":"2.0","result":1,"id":3}` + "\n")},
+ {[]byte(`{"jsonrpc":"2.0","method":"state_subscribeStorage","params":[],"id":4}`), []byte(`{"jsonrpc":"2.0","result":2,"id":4}` + "\n")},
}
-func TestNewWebSocketServer(t *testing.T) {
+func TestHTTPServer_ServeHTTP(t *testing.T) {
coreAPI := core.NewTestService(t, nil)
si := &types.SystemInfo{
SystemName: "gossamer",
}
sysAPI := system.NewService(si)
+ bAPI := new(MockBlockAPI)
+ sAPI := new(MockStorageAPI)
cfg := &HTTPServerConfig{
- Modules: []string{"system", "chain"},
- RPCPort: 8545,
- WSPort: 8546,
- WSEnabled: true,
- RPCAPI: NewService(),
- CoreAPI: coreAPI,
- SystemAPI: sysAPI,
+ Modules: []string{"system", "chain"},
+ RPCPort: 8545,
+ WSPort: 8546,
+ WSEnabled: true,
+ RPCAPI: NewService(),
+ CoreAPI: coreAPI,
+ SystemAPI: sysAPI,
+ BlockAPI: bAPI,
+ StorageAPI: sAPI,
}
s := NewHTTPServer(cfg)
@@ -65,3 +73,42 @@ func TestNewWebSocketServer(t *testing.T) {
require.Equal(t, item.expected, message)
}
}
+
+type MockBlockAPI struct {
+}
+
+func (m *MockBlockAPI) GetHeader(hash common.Hash) (*types.Header, error) {
+ return nil, nil
+}
+func (m *MockBlockAPI) HighestBlockHash() common.Hash {
+ return common.Hash{}
+}
+func (m *MockBlockAPI) GetBlockByHash(hash common.Hash) (*types.Block, error) {
+ return nil, nil
+}
+func (m *MockBlockAPI) GetBlockHash(blockNumber *big.Int) (*common.Hash, error) {
+ return nil, nil
+}
+func (m *MockBlockAPI) GetFinalizedHash(uint64) (common.Hash, error) {
+ return common.Hash{}, nil
+}
+func (m *MockBlockAPI) RegisterImportedChannel(ch chan<- *types.Block) (byte, error) {
+ return 0, nil
+}
+func (m *MockBlockAPI) UnregisterImportedChannel(id byte) {
+}
+
+type MockStorageAPI struct{}
+
+func (m *MockStorageAPI) GetStorage(key []byte) ([]byte, error) {
+ return nil, nil
+}
+func (m *MockStorageAPI) Entries() map[string][]byte {
+ return nil
+}
+func (m *MockStorageAPI) RegisterStorageChangeChannel(ch chan<- *state.KeyValue) (byte, error) {
+ return 0, nil
+}
+func (m *MockStorageAPI) UnregisterStorageChangeChannel(id byte) {
+
+}
diff --git a/dot/services_test.go b/dot/services_test.go
index 9d77413c22..4182cd721c 100644
--- a/dot/services_test.go
+++ b/dot/services_test.go
@@ -17,12 +17,16 @@
package dot
import (
+ "flag"
"math/big"
+ "net/url"
"testing"
+ "time"
"github.com/ChainSafe/gossamer/dot/network"
"github.com/ChainSafe/gossamer/lib/keystore"
"github.com/ChainSafe/gossamer/lib/utils"
+ "github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
)
@@ -229,3 +233,75 @@ func TestCreateGrandpaService(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, gs)
}
+
+var addr = flag.String("addr", "localhost:8546", "http service address")
+var testCalls = []struct {
+ call []byte
+ expected []byte
+}{
+ {[]byte(`{"jsonrpc":"2.0","method":"system_name","params":[],"id":1}`), []byte(`{"id":1,"jsonrpc":"2.0","result":"gossamer"}` + "\n")}, // working request
+ {[]byte(`{"jsonrpc":"2.0","method":"unknown","params":[],"id":2}`), []byte(`{"error":{"code":-32000,"data":null,"message":"rpc error method unknown not found"},"id":2,"jsonrpc":"2.0"}` + "\n")}, // unknown method
+ {[]byte{}, []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n")}, // empty request
+ {[]byte(`{"jsonrpc":"2.0","method":"chain_subscribeNewHeads","params":[],"id":3}`), []byte(`{"jsonrpc":"2.0","result":1,"id":3}` + "\n")},
+ {[]byte(`{"jsonrpc":"2.0","method":"state_subscribeStorage","params":[],"id":4}`), []byte(`{"jsonrpc":"2.0","result":2,"id":4}` + "\n")},
+}
+
+func TestNewWebSocketServer(t *testing.T) {
+ cfg := NewTestConfig(t)
+ require.NotNil(t, cfg)
+
+ genFile := NewTestGenesisFile(t, cfg)
+ require.NotNil(t, genFile)
+
+ defer utils.RemoveTestDir(t)
+
+ cfg.Core.Authority = false
+ cfg.Core.BabeAuthority = false
+ cfg.Core.GrandpaAuthority = false
+ cfg.Init.Genesis = genFile.Name()
+ cfg.RPC.WSEnabled = true
+ cfg.System.SystemName = "gossamer"
+
+ err := InitNode(cfg)
+ require.Nil(t, err)
+
+ stateSrvc, err := createStateService(cfg)
+ require.Nil(t, err)
+
+ coreMsgs := make(chan network.Message)
+ networkMsgs := make(chan network.Message)
+
+ ks := keystore.NewKeystore()
+ rt, err := createRuntime(cfg, stateSrvc, ks)
+ require.NoError(t, err)
+
+ coreSrvc, err := createCoreService(cfg, nil, nil, rt, ks, stateSrvc, coreMsgs, networkMsgs, make(chan *big.Int))
+ require.Nil(t, err)
+
+ networkSrvc := &network.Service{}
+
+ sysSrvc := createSystemService(&cfg.System)
+
+ rpcSrvc, err := createRPCService(cfg, stateSrvc, coreSrvc, networkSrvc, nil, rt, sysSrvc)
+ require.Nil(t, err)
+
+ err = rpcSrvc.Start()
+ require.Nil(t, err)
+
+ time.Sleep(time.Second) // give server a second to start
+
+ u := url.URL{Scheme: "ws", Host: *addr, Path: "/"}
+
+ c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
+ require.NoError(t, err)
+ defer c.Close()
+
+ for _, item := range testCalls {
+ err = c.WriteMessage(websocket.TextMessage, item.call)
+ require.Nil(t, err)
+
+ _, message, err := c.ReadMessage()
+ require.Nil(t, err)
+ require.Equal(t, item.expected, message)
+ }
+}
diff --git a/dot/state/storage.go b/dot/state/storage.go
index 10a123c1be..6cb121fd9c 100644
--- a/dot/state/storage.go
+++ b/dot/state/storage.go
@@ -52,6 +52,10 @@ type StorageState struct {
trie *trie.Trie
db *StorageDB
lock sync.RWMutex
+
+ // change notifiers
+ changed map[byte]chan<- *KeyValue
+ changedLock sync.RWMutex
}
// NewStorageDB instantiates badgerDB instance for storing trie structure
@@ -72,8 +76,9 @@ func NewStorageState(db chaindb.Database, t *trie.Trie) (*StorageState, error) {
}
return &StorageState{
- trie: t,
- db: NewStorageDB(db),
+ trie: t,
+ db: NewStorageDB(db),
+ changed: make(map[byte]chan<- *KeyValue),
}, nil
}
@@ -120,7 +125,16 @@ func (s *StorageState) EnumeratedTrieRoot(values [][]byte) {
func (s *StorageState) SetStorage(key []byte, value []byte) error {
s.lock.Lock()
defer s.lock.Unlock()
- return s.trie.Put(key, value)
+ kv := &KeyValue{
+ Key: key,
+ Value: value,
+ }
+ err := s.trie.Put(key, value)
+ if err != nil {
+ return err
+ }
+ s.notifyChanged(kv)
+ return nil
}
// ClearPrefix not implemented
@@ -133,7 +147,16 @@ func (s *StorageState) ClearPrefix(prefix []byte) {
func (s *StorageState) ClearStorage(key []byte) error {
s.lock.Lock()
defer s.lock.Unlock()
- return s.trie.Delete(key)
+ kv := &KeyValue{
+ Key: key,
+ Value: nil,
+ }
+ err := s.trie.Delete(key)
+ if err != nil {
+ return err
+ }
+ s.notifyChanged(kv)
+ return nil
}
// Entries returns Entries from the trie
diff --git a/dot/state/storage_notify.go b/dot/state/storage_notify.go
new file mode 100644
index 0000000000..cd49a9e911
--- /dev/null
+++ b/dot/state/storage_notify.go
@@ -0,0 +1,76 @@
+// Copyright 2020 ChainSafe Systems (ON) Corp.
+// This file is part of gossamer.
+//
+// The gossamer library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The gossamer library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the gossamer library. If not, see .
+package state
+
+import (
+ "errors"
+)
+
+// KeyValue struct to hold key value pairs
+type KeyValue struct {
+ Key []byte
+ Value []byte
+}
+
+// RegisterStorageChangeChannel function to register storage change channels
+func (s *StorageState) RegisterStorageChangeChannel(ch chan<- *KeyValue) (byte, error) {
+ s.changedLock.RLock()
+
+ if len(s.changed) == 256 {
+ return 0, errors.New("channel limit reached")
+ }
+
+ var id byte
+ for {
+ id = generateID()
+ if s.changed[id] == nil {
+ break
+ }
+ }
+
+ s.changedLock.RUnlock()
+
+ s.changedLock.Lock()
+ s.changed[id] = ch
+ s.changedLock.Unlock()
+ return id, nil
+}
+
+// UnregisterStorageChangeChannel removes the storage change notification channel with the given ID.
+// A channel must be unregistered before closing it.
+func (s *StorageState) UnregisterStorageChangeChannel(id byte) {
+ s.changedLock.Lock()
+ defer s.changedLock.Unlock()
+
+ delete(s.changed, id)
+}
+
+func (s *StorageState) notifyChanged(change *KeyValue) {
+ s.changedLock.RLock()
+ defer s.changedLock.RUnlock()
+
+ if len(s.changed) == 0 {
+ return
+ }
+
+ logger.Trace("notifying changed storage chans...", "chans", s.changed)
+
+ for _, ch := range s.changed {
+ go func(ch chan<- *KeyValue) {
+ ch <- change
+ }(ch)
+ }
+}
diff --git a/dot/state/storage_notify_test.go b/dot/state/storage_notify_test.go
new file mode 100644
index 0000000000..02cd9a0f3d
--- /dev/null
+++ b/dot/state/storage_notify_test.go
@@ -0,0 +1,88 @@
+// Copyright 2020 ChainSafe Systems (ON) Corp.
+// This file is part of gossamer.
+//
+// The gossamer library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The gossamer library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the gossamer library. If not, see .
+package state
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestStorageState_RegisterStorageChangeChannel(t *testing.T) {
+ ss := newTestStorageState(t)
+
+ ch := make(chan *KeyValue, 3)
+ id, err := ss.RegisterStorageChangeChannel(ch)
+ require.NoError(t, err)
+
+ defer ss.UnregisterStorageChangeChannel(id)
+
+ // three storage change events
+ ss.SetStorage([]byte("mackcom"), []byte("wuz here"))
+ ss.SetStorage([]byte("key1"), []byte("value1"))
+ ss.SetStorage([]byte("key1"), []byte("value2"))
+
+ for i := 0; i < 3; i++ {
+ select {
+ case <-ch:
+ case <-time.After(testMessageTimeout):
+ t.Fatal("did not receive storage change message")
+ }
+ }
+}
+
+func TestStorageState_RegisterStorageChangeChannel_Multi(t *testing.T) {
+ ss := newTestStorageState(t)
+
+ num := 5
+ chs := make([]chan *KeyValue, num)
+ ids := make([]byte, num)
+
+ var err error
+ for i := 0; i < num; i++ {
+ chs[i] = make(chan *KeyValue)
+ ids[i], err = ss.RegisterStorageChangeChannel(chs[i])
+ require.NoError(t, err)
+ }
+
+ key1 := []byte("key1")
+ ss.SetStorage(key1, []byte("value1"))
+
+ var wg sync.WaitGroup
+ wg.Add(num)
+
+ for i, ch := range chs {
+
+ go func(i int, ch chan *KeyValue) {
+ select {
+ case c := <-ch:
+ require.Equal(t, key1, c.Key)
+ wg.Done()
+ case <-time.After(testMessageTimeout):
+ t.Error("did not receive storage change: ch=", i)
+ }
+ }(i, ch)
+
+ }
+
+ wg.Wait()
+
+ for _, id := range ids {
+ ss.UnregisterStorageChangeChannel(id)
+ }
+}